# Generating the NaturalProofs ProofWiki domain

This notebook is used to create NaturalProofs's ProofWiki domain (`naturalproofs_proofwiki.json`).

ProofWiki provides a website dump here:
- https://proofwiki.org/wiki/User:Afirou/website_dump

Download the ProofWiki XML dump we used [here](https://drive.google.com/file/d/1pg6ae7xt-PO0ot4F_iJv9uhTLJ8Mr6Gi/view?usp=sharing).

In [None]:
!pip install bs4 wikitextparser nltk jsonlines

In [None]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm
from bs4 import BeautifulSoup as BS
from nltk import ngrams
import jsonlines
import pandas as pd
import os
import pickle
import torch
import wikitextparser as wtp
import os
import pandas as pd
import glob
import re
import xml.etree.ElementTree as ET
pd.set_option('display.max_colwidth', -1)
%pylab inline

#### Load and parse

In [9]:
filepath = './proof_wiki_nov_12_2020.xml'

soup = BS(open(filepath, 'r').read())

#### Parse redirects

We do this first so that we can store links using their redirected names.

In [5]:
redirects = {}

pages = soup.find_all('page')
pages = [page for page in pages if (
    (page.redirect is not None) and
    (not page.title.text.startswith('Talk:')) and
    (not page.title.text.startswith('User:')) and
    (not page.title.text.startswith('User talk:')) and
    (not page.title.text.startswith('Help:'))
)]

for page in tqdm(pages, total=len(pages)):
    redirects[page.title.text] = page.redirect['title']

print("%d redirects" % len(redirects))

100%|██████████| 7617/7617 [00:00<00:00, 34631.54it/s]

7617 redirects





#### Parse theorem title and theorem statement

In [6]:
from tqdm import tqdm

parsed = []
name_to_parsed = {}


pages = soup.find_all('page')
item_pages = [page for page in pages if (
    ("== Theorem ==" in page.text) and
    (not page.title.text.startswith('Talk:')) and
    (not page.title.text.startswith('User:')) and
    (not page.title.text.startswith('User talk:')) and
    (not page.title.text.startswith('Help:')) and
    (not page.redirect)
)]
exceptions = []
discarded = []
n = 0
for page in tqdm(item_pages, total=len(item_pages)):
    # title of the theorem
    theorem_title = page.title.text.replace('\u200e', '')

    # parse WikiMedia format
    text = page.text.replace('\u200e', '').replace('\u2062', '')
    wnode = wtp.parse(text)
    
    # get the WikiMedia Section that has "== Theorem ==" as its title
    theorem_sections = [s for s in wnode.sections if s.title is not None and s.title.strip() == 'Theorem']
    
    if len(theorem_sections) == 0:
        exceptions.append(page)
        continue
    else:
        theorem_section = theorem_sections[0]
    
    # get the content inside the <onlyinclude> tag, if there is one.
    tags = [t for t in theorem_section.tags() if t.name == 'onlyinclude']
    if len(tags) == 0:
        # remove subsections 
        contents = theorem_section.contents.strip().split('\n\n===')[0].strip()
        plain_text = theorem_section.plain_text().split("\n\n===")[0].strip()
        links = [l for l in theorem_section.wikilinks if l.title in contents]
    else:
        n += 1
        contents = tags[0].contents.strip()
        plain_text = tags[0].plain_text().strip()
        links = tags[0].wikilinks
        
    if plain_text == '== Theorem ==':
        contents = ""
        links = []

    links = [l.title for l in links]
    links = [redirects.get(l, l) for l in links]
    
    categories = [node.title.split('Category:')[1] for node in wnode.wikilinks if node.title.startswith('Category:')]
    
    data = {
        'type': 'theorem',
        'title': theorem_title,
        'contents': contents,
        'full_contents': text,
        'links': links,
        'has_contents': contents != '',
        'categories': categories,
    }
    parsed.append(data)
    name_to_parsed[theorem_title] = data
    
print("%d parsed, %d exceptions." % (len(parsed), len(exceptions)))
print("has theorem content: %d" % (len([x for x in parsed if x['has_contents']])))

100%|██████████| 20203/20203 [00:24<00:00, 828.52it/s] 

19734 parsed, 469 exceptions.
has theorem content: 16473





#### Definitions

In [8]:
parsed_def = []
name_to_parsed_def = {}


pages = soup.find_all('page')
item_pages = [page for page in pages if (
    ("Definition:" in page.title.text) and
    (not page.title.text.startswith('Talk:')) and
    (not page.title.text.startswith('User:')) and
    (not page.title.text.startswith('User talk:')) and
    (not page.title.text.startswith('Help:')) and 
    (not page.redirect) and
    (not page.title.text in name_to_parsed)  # some 'definitions' are actually theorem pages, e.g. Definition:Stabilizer
)]
discarded = []
for page in tqdm(item_pages, total=len(item_pages)):
    title = page.title.text.replace('\u200e', '').replace('\u2062', '')
    
    # parse WikiMedia format
    text = page.text.replace('\u200e', '').replace('\u2062', '')
    wnode = wtp.parse(text)
    
    sections = [s for s in wnode.sections if s.title is not None and s.title.strip() == 'Definition']

    if len(sections) == 0:
        contents = ""
        links = []
    else:
        section = sections[0]

        # get the content inside the <onlyinclude> tag, if there is one.
        tags = [t for t in section.tags() if t.name == 'onlyinclude']
        if len(tags) == 0:
            contents = section.contents.strip()
            links = section.wikilinks
            links = [l.title for l in links]
        else:
            contents = tags[0].contents.strip()
            links = tags[0].wikilinks
            links = [l.title for l in links]
    
    links = [redirects.get(l, l) for l in links]
    categories = [node.title.split('Category:')[1] for node in wnode.wikilinks if node.title.startswith('Category:')]
   
    data = {
        'type': 'definition',
        'title': title,
        'contents': contents,
        'full_contents': text,
        'links': links,
        'has_contents': contents != "",
        'categories': categories
    }
    parsed_def.append(data)
    name_to_parsed_def[title] = data
    

print("%d parsed" % (len(parsed_def)))
print("has definition content: %d" % (len([x for x in parsed_def if x['has_contents']])))

100%|██████████| 12420/12420 [00:07<00:00, 1576.04it/s]

12420 parsed
has definition content: 9982





#### Proofs

In [10]:
parsed_proof = []
name_to_parsed_proof = {}

pages = soup.find_all('page')
item_pages = [page for page in pages if (
    (not page.title.text.startswith('Talk:')) and
    (not page.title.text.startswith('User:')) and
    (not page.title.text.startswith('User talk:')) and
    (not page.title.text.startswith('Help:'))
)]
exceptions = []
discarded = []
missing_links = []

for page in tqdm(item_pages, total=len(item_pages)):
    title = page.title.text.strip('\u200e')
        
    # parse WikiMedia format
    text = page.text.replace('\u200e', '').replace('\u2062', '')
    wnode = wtp.parse(text)  
    
    # get the WikiMedia Section that has "== Proof ==" as its title
    sections = [s for s in wnode.sections if s.title is not None and s.title.strip() == 'Proof']

    if len(sections) == 0:
        exceptions.append(page)
        continue
    else:
        section = sections[0]

    # get the content inside the <onlyinclude> tag, if there is one.
    tags = [t for t in section.tags() if t.name == 'onlyinclude']
    if len(tags) == 0:
        contents = section.contents.strip()
        links = section.wikilinks
    else:
        contents = tags[0].contents.strip()
        links = tags[0].wikilinks

    links = [l.title.strip('\u200e') for l in links]
    links = [redirects.get(l, l) for l in links]

        
    for ltitle in links:
        if ltitle not in name_to_parsed and ltitle not in name_to_parsed_def:
            if ltitle.startswith('Category:') or ltitle.startswith('File:'):
                continue
            missing_links.append(ltitle)
    
    # skip proofs without a proof
    if contents == '{{ProofWanted}}' or contents == '{{proof wanted}}' or contents == '{{finish}}' or contents == '{{Finish}}':
        continue
        
    categories = [node.title.split('Category:')[1] for node in wnode.wikilinks if node.title.startswith('Category:')]

    data = {
        'type': 'proof',
        'title': title,
        'contents': contents,
        'links': links,
        'categories': categories
    }
    parsed_proof.append(data)
    name_to_parsed_proof[title] = data
    

print(len(missing_links), len(set(missing_links)))
print("%d parsed, %d exceptions." % (len(parsed_proof), len(exceptions)))

100%|██████████| 67077/67077 [00:33<00:00, 1973.51it/s]

4988 2019
19956 parsed, 46454 exceptions.





#### Parse pages that are linked to that we missed

- Axioms
- Proof techniques
- Corollaries
- ...

In [11]:
missing_links_set = set(missing_links)
pages = soup.find_all('page')
pages = [page for page in pages if page.title.text in missing_links_set]

parsed_extra = []
name_to_parsed_extra = {}

for page in tqdm(pages, total=len(pages)):
    # title of the theorem
    title = page.title.text.strip('\u200e')
    
    # parse WikiMedia format
    text = page.text.replace('\u200e', '').replace('\u2062', '')
    wnode = wtp.parse(text)  
    
    links = wnode.wikilinks if wnode.wikilinks is not None else []
    links = [l.title.strip('\u200e') for l in links]
    links = [redirects.get(l, l) for l in links]

    categories = [node.title.split('Category:')[1] for node in wnode.wikilinks if node.title.startswith('Category:')]

    data = {
        'type': 'extra',
        'title': title,
        'contents': text,
        'full_contents': text,
        'has_contents': True,
        'links': links,
        'categories': categories
    }
    parsed_extra.append(data)
    name_to_parsed_extra[title] = data
    

print("%d parsed." % (len(parsed_extra)))

100%|██████████| 1006/1006 [00:00<00:00, 2127.86it/s]

1006 parsed.





#### Next, we remove the remaining missing links

In [12]:
link_set = set()
for name in name_to_parsed:
    link_set.add(name)

for name in name_to_parsed_def:
    link_set.add(name)
    
for name in name_to_parsed_extra:
    link_set.add(name)
    
print(len(link_set))

33160


In [13]:
for name, item in name_to_parsed.items():
    links = item['links']
    links = [l for l in links if l in link_set]
    item['links'] = links

for name, item in name_to_parsed_def.items():
    links = item['links']
    links = [l for l in links if l in link_set]
    item['links'] = links
    
for name, item in name_to_parsed_proof.items():
    links = item['links']
    links = [l for l in links if l in link_set]
    item['links'] = links
    
for name, item in name_to_parsed_extra.items():
    links = item['links']
    links = [l for l in links if l in link_set]
    item['links'] = links


#### Add a flag to denote whether a theorem has a proof

In [16]:
theorem_no_proof = set()

for name in tqdm(name_to_parsed, total=len(name_to_parsed)):
    theorem = name_to_parsed[name]
    proof_names = []
    for proof_name in name_to_parsed_proof:
        if name == proof_name:
            proof_names.append(proof_name)
        elif name in proof_name and '/' in proof_name and len(proof_name.split('/')) == 2:
            suffix = proof_name.split('/')[-1]
            if 'proof' in suffix.lower():
                proof_names.append(proof_name)
        else:  # NOTE: we discard some proofs this way
            pass
    theorem['proof_names'] = proof_names
    theorem['has_proof'] = len(proof_names) > 0


100%|██████████| 19734/19734 [01:22<00:00, 239.83it/s]


### Produce dataset

#### Utilities

In [18]:
def replace_links(lines):
    def __replace(line):
        matches = re.findall(r'(\[\[([^]]*)\]\])', line)
        for match in matches:
            full, inner = match
            splt = inner.split('|')
            if len(splt) == 1:
                txt = splt[0]
            elif len(splt) == 2:
                txt = splt[1]
            else:
                txt = ''.join(splt[1:])
            if full in line:
                line = line.replace(full, txt)
        return line
    lines_ = [
        __replace(line) for line in lines
    ]
    return lines_

Theorems with proofs.

In [19]:
examples_json = {
    'examples': [],
    'theorems': [],
    'definitions': [],
    'other': [],
    'proofs': []
}

# `examples`: contains theorems that have contents and at least one proof with at least one reference 
for theorem_name in name_to_parsed:
    theorem = name_to_parsed[theorem_name]
    
    if not theorem['has_proof'] or not theorem['has_contents']:
        continue
    
    example = {}
    example['type'] = 'theorem'
    example['has_proof'] = theorem['has_proof']
    example['title'] = theorem['title']
    example['proof_titles'] = theorem['proof_names']
    example['categories'] = theorem['categories']
    example['statement'] = {
        'contents': [line for line in theorem['contents'].split('\n') if line != ''],
        'refs': theorem['links'],
    }
    example['statement']['read_contents'] = replace_links(example['statement']['contents'])

    nrefs = 0
    example['proofs'] = []    
    for proof_name in theorem['proof_names']:
        proof_ = name_to_parsed_proof[proof_name]
        proof = {
            'title': proof_name,
            'refs': name_to_parsed_proof[proof_name]['links']
        }
        example['proofs'].append(proof)
        
        nrefs += len(proof['refs'])
    
    # only keep if there is at least one reference
    if nrefs == 0:
        continue
    
    if len(example['proofs']) == 0:
        print(theorem_name)
        
    examples_json['examples'].append(example)

    
# store _all_ theorems (including `full_contents` as well) separately
for name in name_to_parsed:
    item = name_to_parsed[name]
    example = {}
    example['type'] = 'theorem'
    example['has_proof'] = item['has_proof']
    example['has_contents'] = item['has_contents']
    example['title'] = item['title']
    example['proof_titles'] = item['proof_names']
    example['contents'] = [line for line in item['contents'].split('\n') if line != '']
    example['read_contents'] = replace_links(example['contents'])
    example['refs'] = item['links']
    example['categories'] = item['categories']
    examples_json['theorems'].append(example)
    
# store all proofs separately
for name in name_to_parsed_proof:
    item = name_to_parsed_proof[name]
    example = {
        'type': 'proof',
        'title': item['title'],
        'contents': [line for line in item['contents'].split('\n') if line != ''],
        'refs': item['links'],
        'categories': item['categories']
    }
    example['read_contents'] = replace_links(example['contents'])
    examples_json['proofs'].append(example)

# store all definitions separately
for name in name_to_parsed_def:
    item = name_to_parsed_def[name]
    example = {
        'type': 'definition',
        'title': item['title'],
        'has_contents': item['has_contents'],
        'contents': [line for line in item['contents'].split('\n') if line != ''],
        'refs': item['links'],
        'categories': item['categories']
    }
    example['read_contents'] = replace_links(example['contents'])
    examples_json['definitions'].append(example)

# store all additional pages that are linked to
for name in name_to_parsed_extra:
    item = name_to_parsed_extra[name]
    example = {
        'type': 'other',
        'title': item['title'],
        'has_contents': item['has_contents'],
        'contents': [line for line in item['contents'].split('\n') if line != ''],
        'refs': item['links'],
        'categories': item['categories']
    }
    example['read_contents'] = replace_links(example['contents'])
    examples_json['other'].append(example)


#### Remove duplicate examples
- $(x_1,y_1),(x_2,y_2)$ such that the contents of $y_1$ exactly matches the contents of $y_2$.

In [20]:
from collections import defaultdict
duplicates = defaultdict(list)
for example in tqdm(examples_json['examples']):
    for example2 in examples_json['examples']:
        if example['title'] == example2['title']:
            continue
        
        for proof1_title in example['proof_titles']:
            for proof2_title in example2['proof_titles']:
                proof1 = name_to_parsed_proof[proof1_title]
                proof2 = name_to_parsed_proof[proof2_title]
                if proof1['contents'] == proof2['contents']:
                    duplicates[example['title']].append(example2['title'])
                    
print(len(duplicates))

100%|██████████| 13859/13859 [05:17<00:00, 43.59it/s]

328





In [21]:
removed = set()
to_remove = []
for title in sorted(duplicates.keys()):
    if title in removed:
        continue
    
    for dup_title in duplicates[title]:
        if dup_title not in removed:
            to_remove.append(dup_title)
            removed.add(dup_title)

len(to_remove)

199

In [22]:
examples = [x for x in examples_json['examples'] if x['title'] not in to_remove]
examples_json['examples'] = examples

Theorem statement is the same:

In [23]:
from collections import defaultdict
duplicates = defaultdict(list)
for example in tqdm(examples_json['examples']):
    for example2 in examples_json['examples']:
        if example['title'] == example2['title']:
            continue
        
        if ''.join(example['statement']['contents']) == ''.join(example2['statement']['contents']):
            duplicates[example['title']].append(example2['title'])
                    
print(len(duplicates))

100%|██████████| 13660/13660 [02:57<00:00, 76.92it/s]

107





In [24]:
removed = set()
to_remove = []
for title in sorted(duplicates.keys()):
    if title in removed:
        continue
    
    for dup_title in duplicates[title]:
        if dup_title not in removed:
            to_remove.append(dup_title)
            removed.add(dup_title)

len(to_remove)

63

In [25]:
examples = [x for x in examples_json['examples'] if x['title'] not in to_remove]
examples_json['examples'] = examples

Finally, assign each reference (theorem/definition/other) a unique id, and include the reference ids in the examples.

In [26]:
name_to_id = {}
proof_name_to_id = {}
retrieval_set = []
for item in examples_json['theorems']:
    if item['title'] not in name_to_id:
        name_to_id[item['title']] = len(name_to_id)
        item['id'] = name_to_id[item['title']]
    else:
        print(item['title'])
        
for item in examples_json['definitions']:
    if item['title'] not in name_to_id:
        name_to_id[item['title']] = len(name_to_id)
        item['id'] = name_to_id[item['title']]
    else:
        print(item['title'])
        
for item in examples_json['other']:
    if item['title'] not in name_to_id:
        name_to_id[item['title']] = len(name_to_id)
        item['id'] = name_to_id[item['title']]
    else:
        print(item['title'])

for i, item in enumerate(examples_json['proofs']):
    name = 'proof_'+item['title']
    if name not in name_to_id:
        proof_name_to_id[name] = i
        item['proof_id'] = proof_name_to_id[name]
    else:
        print(name)
        
for i, example in enumerate(examples_json['examples']):
    example['example_id'] = i
    example['theorem_id'] = name_to_id[example['title']]
    # references in statement
    ref_ids = [name_to_id[ref] for ref in example['statement']['refs']]
    example['statement']['ref_ids'] = ref_ids
    
    # references in proofs
    for proof in example['proofs']:
        ref_ids = [name_to_id[ref] for ref in proof['refs']]
        proof['ref_ids'] = ref_ids
        proof['proof_id'] = proof_name_to_id['proof_'+proof['title']]

Rename `examples` as `retrieval_examples`

In [27]:
examples_json['retrieval_examples'] = examples_json.pop('examples')

In [28]:
for k, vs in examples_json.items():
    print("%s\t%d" % (k, len(vs)))

theorems	19734
definitions	12420
other	1006
proofs	19956
retrieval_examples	13597


In [29]:
dataset = {
    'dataset': examples_json,
}


import json
output_json = './naturalproofs_proofwiki.json'
with open(output_json, 'w') as f:
    json.dump(dataset, f)

# Dataset splits

Form dataset splits using the reference graph.

In [None]:
%pylab inline
import pandas as pd
import glob
import pickle
import os
import json
from pprint import pprint as pp
from collections import defaultdict
import torch
from tqdm import tqdm
import networkx

In [2]:
raw_ds = json.load(open(output_json, 'r'))

### Form the reference graph

$G=(V,E)$
- $v\in V$: theorem, definition, other page
- $(u, v)\in E$: $u$ occurs in the statement or a proof of $v$

In [3]:
refs = raw_ds['dataset']['theorems'] + raw_ds['dataset']['definitions'] + raw_ds['dataset']['other']

graph = defaultdict(list)

id2ref = {}
ref2id = {}
for r in refs:
    ref2id[r['title']] = r['id']
    id2ref[r['id']] = r
    
title2proof = {}
for p in raw_ds['dataset']['proofs']:
    title2proof[p['title']] = p
    
pairs = []
cycles = []
for r1 in refs:
    
    # Make an edge for each reference in the _statement_
    for r2 in r1['refs']:
        
        r1id = r1['id']
        r2id = ref2id[r2]
        
        if r1id != r2id:
            graph[r2id].append(r1id)
            
            pairs.append((r2id, r1id))
            
            if r2id in graph[r1id]:
                cycles.append(tuple(sorted((r2id, r1id))))
                
    # Make an edge for each reference in the _proof_ (when available)
    if r1['type'] == 'theorem' and r1['has_proof']:
        for title in r1['proof_titles']:
            proof = title2proof[title]
            
            for r2 in proof['refs']:                
                r1id = r1['id']
                r2id = ref2id[r2]
                if r1id != r2id:
                    graph[r2id].append(r1id)
                    
                    pairs.append((r2id, r1id))

                    if r2id in graph[r1id]:
                        cycles.append(tuple(sorted((r2id, r1id))))

cycles = set(cycles)
print("%d 1-cycles" % (len(cycles)))

877 1-cycles


In [4]:
G = networkx.DiGraph(graph)
leafs = [node for node in G.nodes() if G.in_degree(node) != 0 and G.out_degree(node)==0]
nonleafs = [node for node in G.nodes() if G.in_degree(node) == 0 or G.out_degree(node) != 0]
heads = [node for node in G.nodes() if G.in_degree(node) == 0 and G.out_degree(node) > 0]

print("%d nodes\n%d leaf\n%d non-leaf\n\n%d heads" % (
    len(G.nodes()),
    len(leafs),
    len(nonleafs),
    len(heads)
))

30681 nodes
12396 leaf
18285 non-leaf

1392 heads


#### BFS layers

Form BFS layers, count the number of nodes, example-worthy theorems (has proof(s) + contents), and 1-cycles.

In [6]:
print("total nodes %d\n" % len(G.nodes()))

incycle = set()
for a, b in cycles:
    incycle.add(a)
    incycle.add(b)

# theorems that correspond to examples (e.g. has a proof, contents)
tid2eid = {}
for item in raw_ds['dataset']['retrieval_examples']:
    tid2eid[item['theorem_id']] = item['example_id']
    
layers = defaultdict(set)
nleafs = []

seen = set()
for node in heads:
    layers[0].add(node)
    seen.add(node)
    
layer = 0
print('layer', 'nodes', 'thms', 'leaf_thms', sep='\t')

while len(layers[layer]) > 0:
    thms = [x for x in layers[layer] if x in tid2eid]
    leaf_thms = [x for x in layers[layer] if x in tid2eid
        and x in leafs
        and (x not in incycle)
    ]
    nleafs.append(len(leaf_thms))
    
    print(layer, len(layers[layer]), len(thms), len(leaf_thms), sep='\t')
    for node in layers[layer]:
        for child in G.successors(node):
            if child not in seen:
                layers[layer+1].add(child)
                seen.add(child)
    layer += 1

nleafs.append(0)
    
nleafs = np.array(nleafs)

total nodes 30681

layer	nodes	thms	leaf_thms
0	1392	0	0
1	10850	5777	2376
2	11713	5644	2581
3	5239	1898	788
4	1322	252	107
5	141	25	12
6	14	1	0


#### Define the train, valid, test splits

We define valid $\cup$ test as leaves, selected at each layer proportional to the number of leaves at the layer.

In [7]:
budget = 2200
leaf_frac = nleafs/nleafs.sum()

In [8]:
rand = np.random.RandomState(42)

splits = defaultdict(set)

for layer in range(len(layers)):
        
    # get number of eval leaves for this layer
    nleaf = int(budget*leaf_frac[layer])
    
    # randomly sample `nleaf` leaf theorems
    leaf_thms = [x for x in layers[layer] if x in tid2eid
        and x in leafs
        and (x not in incycle)
    ]
    perm = rand.permutation(len(leaf_thms))
    eval_thms = [leaf_thms[i] for i in perm[:nleaf]]
    
    # collect as evaluation theorems and references
    for x in eval_thms:
        splits['eval_thms'].add(x)
        splits['eval_refs'].add(x)
    
    # collect all other items as training data
    eval_thms_set = set(eval_thms)
    for x in layers[layer]:
        if x not in eval_thms_set:
            splits['train_refs'].add(x)
            if x in tid2eid:
                splits['train_thms'].add(x)
                
for k in splits:
    splits[k] = list(splits[k])
    print(k, len(splits[k]))

train_refs 28473
eval_thms 2198
eval_refs 2198
train_thms 11399


#### Verify that evaluation theorems are not referred in training

In [9]:
for x in tqdm(splits['eval_thms'], total=len(splits['eval_thms'])):
    for y in splits['train_refs']:
        if G.has_predecessor(y, x):
            print(id2ref[x]['title'], id2ref[y]['title'])

100%|██████████| 2198/2198 [00:42<00:00, 51.90it/s]


#### Randomly split evaluation into validation and test.

In [10]:
rand = np.random.RandomState(42)
perm = rand.permutation(len(splits['eval_thms']))

idx = len(splits['eval_thms'])//2
val_idxs = perm[:idx]

val_thms = [splits['eval_thms'][i] for i in perm[:idx]]
tst_thms = [splits['eval_thms'][i] for i in perm[idx:]]


#### Convert theorem ids to example ids

In [12]:
tid2eid = {}
for item in raw_ds['dataset']['retrieval_examples']:
    tid2eid[item['theorem_id']] = item['example_id']

In [13]:
final_splits = {
    'train': {},
    'valid': {},
    'test': {}
}

final_splits['train']['ref_ids'] = splits['train_refs']
final_splits['train']['example_ids'] = [tid2eid[t] for t in splits['train_thms']]

final_splits['valid']['ref_ids'] = splits['train_refs'] + splits['eval_refs']
final_splits['valid']['example_ids'] = [tid2eid[t] for t in val_thms]

final_splits['test']['ref_ids'] = splits['train_refs'] + splits['eval_refs']
final_splits['test']['example_ids'] = [tid2eid[t] for t in tst_thms]

In [14]:
for k in final_splits:
    print(k)
    for k2 in final_splits[k]:
        print(k2, len(final_splits[k][k2]))
    print()

train
ref_ids 28473
example_ids 11399

valid
ref_ids 30671
example_ids 1099

test
ref_ids 30671
example_ids 1099



In [15]:
raw_ds['splits'] = final_splits

import json
with open(output_json, 'w') as f:
    json.dump(raw_ds, f)

## Normalize to NaturalProofs data schema

This step converts the data to adhere to the NaturalProofs data schema.

In [None]:
import json

with open(output_json) as f:
    js = json.load(f)

print(js.keys())
print(js['dataset'].keys())

theorems = js['dataset']['theorems']
definitions = js['dataset']['definitions']
others = js['dataset']['other']
proofs = js['dataset']['proofs']
retrieval_examples = js['dataset']['retrieval_examples']
splits = js['splits']

new_theorems = []
new_definitions = []
new_others = []
new_retrieval_examples = []
new_splits = {}

title2id = {}
for item in theorems + definitions + others:
    title2id[item['title']] = item['id']

print(proofs[0].keys())
title2proof = {}
for proof in proofs:
    title2proof[proof['title']] = {
        'contents': proof['contents'],
        'refs': proof['refs'],
        'ref_ids': [title2id[title] for title in proof['refs']],
    }

print(theorems[0].keys())
for item in theorems:
    new_theorems.append({
        'id': item['id'],
        'type': item['type'],
        'label': item['title'],
        'title': item['title'],
        'categories': item['categories'],
        'contents': item['contents'],
        'refs': item['refs'],
        'ref_ids': [title2id[title] for title in item['refs']],
        'proofs': [title2proof[t] for t in item['proof_titles']],
    })

for item in definitions:
    new_definitions.append({
        'id': item['id'],
        'type': item['type'],
        'label': item['title'],
        'title': item['title'],
        'categories': item['categories'],
        'contents': item['contents'],
        'refs': item['refs'],
        'ref_ids': [title2id[title] for title in item['refs']],
        'proofs': [],
    })

for item in others:
    new_others.append({
        'id': item['id'],
        'type': item['type'],
        'label': item['title'],
        'title': item['title'],
        'categories': item['categories'],
        'contents': item['contents'],
        'refs': item['refs'],
        'ref_ids': [title2id[title] for title in item['refs']],
        'proofs': [],
    })

id2item = {}
for item in new_theorems + new_definitions + new_others:
    id2item[item['id']] = item

eid2tid = {}
for e in retrieval_examples:
    eid2tid[e['example_id']] = e['theorem_id']
    
print(retrieval_examples[0].keys())
new_retrieval_examples = [e['theorem_id'] for e in retrieval_examples]

for split in ['train', 'valid', 'test']:
    new_splits[split] = {
        'ref_ids': splits[split]['ref_ids'],
        'examples': sum([[(eid2tid[eid], j) for j in range(len(id2item[eid2tid[eid]]['proofs'])) \
                                            if len(id2item[eid2tid[eid]]['proofs'][j]['refs']) > 0] \
                         for eid in splits[split]['example_ids']], []),
    }

js = {
    'dataset': {
        'theorems': new_theorems,
        'definitions': new_definitions,
        'others': new_others,
        'retrieval_examples': new_retrieval_examples,
    },
    'splits': new_splits,
}

with open(output_json, 'w') as f:
    json.dump(js, f, indent=4)

### Category Graph

This step adds extra category tags to the dataset.

In [None]:
import requests

cats = {}

prefix = 'https://www.proofwiki.org/'
root = '/wiki/Category:Content_Categories'

def dfs(title, suffix, parent):
    print(str(len(cats)) + ' ' + title)
    url = prefix + suffix
    page = requests.get(url)
    soup = BS(page.content, 'html.parser')
    
    children = []
    children_titles = []
    
    for item in soup.find_all('div', class_='CategoryTreeItem'):
        a = item.find('a')
        children.append((a.text, a['href']))
        children_titles.append(a.text)
    
    cat = {
        'title': title,
        'suffix': suffix,
        'parent': parent,
        'children': children_titles,
    }
    cats[title] = cat
    
    for (child_title, child_suffix) in children:
        if child_title not in cats:
            dfs(child_title, child_suffix, title)

dfs('Content Categories', root, None)

In [41]:
for cat in cats:
    cats[cat]['parents'] = []
for cat in cats:
    for c in cats[cat]['children']:
        cats[c]['parents'].append(cat)

In [None]:
cat2kind = {}
def dfs2(cat, kind):
    if cat in cat2kind and cat2kind[cat] != kind:
        print(cat)
    cat2kind[cat] = kind
    visited.add(cat)
    for child in cats[cat]['children']:
        if child not in visited:
            dfs2(child, kind)

visited = set()
dfs2('Proofs by Topic', 'proof')
visited = set()
dfs2('Definitions/Branches of Mathematics', 'definition')

In [57]:
with open(output_json) as f:
    js = json.load(f)

theorems = js['dataset']['theorems']

In [58]:
categories = set()
for item in theorems + definitions + others:
    for cat in item['categories']:
        categories.add(cat)
print('%d cats appear in dataset' % len(categories))

missing = 0
for cat in categories:
    if cat.strip(' ') not in cats:
        all_numbers = True
        for c in cat:
            if not (c == ',' or '0' <= c <= '9'):
                all_numbers = False
        if all_numbers:
            continue
        missing += 1
print('%d appeared cats missing from cat graph' % missing)

5339 cats appear in dataset
49 appeared cats missing from cat graph


In [127]:
def get_toplevel(cat, stack=[]):
    if cat is None:
        return set()
    if cat in stack:
        return set()
    ans = set()
    for p in cats[cat]['parents']:
        if p == 'Proofs by Topic':
            ans.add(cat)
        else:
            res = get_toplevel(p, stack + [cat])
            for k in res:
                ans.add(k)
    return ans

def get_recursive(cat, stack=[]):
    if cat is None:
        return set()
    if cat in stack:
        return set()
    for p in cats[cat]['parents']:
        if p == 'Proofs by Topic':
            recs.add(cat)
            for k in stack:
                recs.add(k)
        else:
            get_recursive(p, stack + [cat])

for item in theorems:
    tls = set()
    recs = set()
    for cat in item['categories']:
        if cat not in cats: continue
        res = get_toplevel(cat)
        for k in res:
            tls.add(k)
        get_recursive(cat)
    item['toplevel_categories'] = list(tls)
    item['recursive_categories'] = list(recs)

In [128]:
js['dataset']['theorems'] = theorems

with open(output_json, 'w') as f:
    json.dump(js, f, indent=4)