# Generating the NaturalProofs Stacks domain

This notebook is used to create the NaturalProofs Stacks domain (`naturalproofs_stacks.json`).

First, clone the [Stacks project GitHub repository](https://github.com/stacks/stacks-project) into `./stacks-project` (we used commit 4df67b8).

In [19]:
import glob
import os
import json
import re
from collections import defaultdict
from tqdm import tqdm
import numpy as np

In [20]:
# Collect the `.tex` sources of the Stacks project checkout.
# `glob` returns paths in arbitrary, filesystem-dependent order. Sorting keeps
# the file order -- and therefore the ids assigned while parsing and the seeded
# valid/test split derived from them -- reproducible across machines.
files_ = sorted(glob.glob('./stacks-project/*.tex'))

# Every path containing 'coding.tex' is left out of the dataset.
files = [f for f in files_ if 'coding.tex' not in f]

# Stem = file name up to `.tex`; it is used as the prefix of every label.
stems = [os.path.basename(f).split('.tex')[0] for f in files]

In [21]:
# Collect the "kind" prefix (the text before the first '-') of every
# \label{...} found in the sources, e.g. `lemma` for `\label{lemma-foo}`.
all_types = set()
for f in files:
    # Use a context manager so each file handle is closed right after reading.
    with open(f) as fh:
        tex = fh.read()

    # NOTE(review): inside a character class `|` is a literal pipe, not an
    # alternation, so this pattern also accepts '|' in label names. It is kept
    # unchanged so the set of matched labels stays the same.
    labels_ = re.findall(r'\\label{([a-z|A-Z|0-9|\-]+)}', tex)
    all_types.update(l.split('-')[0] for l in labels_)

all_types

{'definition',
 'equation',
 'example',
 'exercise',
 'item',
 'lemma',
 'proposition',
 'remark',
 'remarks',
 'section',
 'situation',
 'subsection',
 'theorem'}

### Parse dataset

In [24]:
def extract_refs(s):
    """Return the labels referenced via ``\\ref{...}`` in ``s``, fully qualified.

    Reads three module-level names at call time: ``exclude_kinds`` (kinds of
    labels to ignore), ``stems`` (all file stems) and ``stem`` (the stem of the
    file currently being parsed).

    A reference is dropped when *any* of its hyphen-separated tokens is listed
    in ``exclude_kinds``. A kept reference that does not already start with one
    of ``stems`` is treated as local and prefixed with ``'<stem>-'``.
    """
    qualified = []
    for target in re.findall(r'\\ref{([^}]*)}', s):
        # Skip references with an excluded kind anywhere among their tokens.
        if any(token in exclude_kinds for token in target.split('-')):
            continue
        # References that already carry a file stem are kept verbatim.
        if any(target.startswith(prefix) for prefix in stems):
            qualified.append(target)
        else:
            qualified.append('%s-%s' % (stem, target))
    return qualified

def parse_proof(statement):
    """Parse the body of a proof.

    Args:
        statement: raw text of the proof.

    Returns:
        A dict with
        ``'contents'``: the stripped text split on newlines, with empty lines
        removed, and
        ``'refs'``: the result of ``extract_refs`` applied to ``statement``.
    """
    contents = statement.strip().split('\n')
    contents = list(filter(lambda s: s != '', contents))
    # Extract references from the argument itself. The previous version read a
    # module-level `proof` variable here, which only worked when the caller
    # happened to have bound the same text to that name.
    refs = extract_refs(statement)

    return {
        'contents': contents,
        'refs': refs,
    }

def parse_item(statement):
    """Parse the body of a labelled environment into a reference record.

    The first line containing ``\\label`` provides the label; only the lines
    after it (minus empty ones) become ``'contents'``. The label is prefixed
    with the module-level ``stem`` and doubles as the title. References are
    extracted from the whole ``statement`` with ``extract_refs``.

    Raises:
        ValueError: if no line contains ``\\label``.
    """
    lines = statement.strip().split('\n')

    # Position of the first line that mentions \label, if any.
    label_at = next(
        (pos for pos, text in enumerate(lines) if '\\label' in text),
        None,
    )
    if label_at is None:
        raise ValueError('no label')

    raw_label = re.findall(r'\\label{([^}]*)}', lines[label_at])[0]
    label = '%s-%s' % (stem, raw_label)

    contents = [text for text in lines[label_at + 1:] if text != '']
    refs = extract_refs(statement)

    return {
        'label': label,
        'categories': [stem],
        'title': label,
        'contents': contents,
        'refs': refs,
    }

In [25]:
# Environments that become entries of the dataset, grouped by entry type.
theorem_kinds = ['theorem', 'lemma', 'proposition']
definition_kinds = ['definition']
other_kinds = ['remark', 'remarks']
all_ref_kinds = theorem_kinds + definition_kinds + other_kinds

# Every other label kind seen in the sources is ignored when extracting refs.
exclude_kinds = [label_kind for label_kind in all_types if label_kind not in all_ref_kinds]

# Environment name -> dataset entry type.
kind2type = {
    kind: type_name
    for type_name, kinds in (
        ('theorem', theorem_kinds),
        ('definition', definition_kinds),
        ('other', other_kinds),
    )
    for kind in kinds
}

# Parsed entries, bucketed by type, plus a label -> id lookup.
theorems = []
definitions = []
others = []
label2id = {}
cnt = 0  # running id, shared by all entry types

for f in files:
    # Read through a context manager so the file handle is closed promptly.
    with open(f) as fh:
        tex = fh.read()
    # `stem` is read as a module-level name by parse_item / extract_refs.
    stem = os.path.basename(f).split('.tex')[0]

    for kind in all_ref_kinds:
        begin_tag = '\\begin{%s}' % kind
        end_tag = '\\end{%s}' % kind

        # Each chunk starts right after a \begin{kind}; the text before the
        # first \begin{kind} is discarded.
        for split in tex.split(begin_tag)[1:]:
            item = {
                'id': cnt,
                'type': kind2type[kind],
            }
            cnt += 1

            # Exactly one \end{kind} is expected per chunk; otherwise this
            # unpacking raises ValueError.
            statement, other = split.split(end_tag)
            item.update(parse_item(statement))

            if kind in theorem_kinds:
                # `proof` is rebound step by step: text up to the first
                # \end{proof} -> regex matches -> proof body -> parsed dict.
                proof = other.split('\\end{proof}')[0]
                proof = re.findall(r'\\begin{proof}(.*)', proof, re.DOTALL)
                assert len(proof) == 1
                proof = proof[0]
                proof = parse_proof(proof)
                item['proofs'] = [proof]

                theorems.append(item)

            elif kind in definition_kinds:
                definitions.append(item)

            elif kind in other_kinds:
                others.append(item)

            label2id[item['label']] = item['id']

#### Add `ref_ids`

In [26]:
# Resolve every referenced label to the id of the entry it points to.
# A label missing from `label2id` raises KeyError.
for entry in theorems + definitions + others:
    entry['ref_ids'] = [label2id[name] for name in entry['refs']]
    # Only theorem entries carry a 'proofs' list; other entries skip this loop.
    for prf in entry.get('proofs', []):
        prf['ref_ids'] = [label2id[name] for name in prf['refs']]

In [27]:
# A theorem is a retrieval example when its first proof cites at least one reference.
retrieval_examples = []
for thm in theorems:
    thm_proofs = thm['proofs']
    if len(thm_proofs) > 0 and len(thm_proofs[0]['refs']) > 0:
        retrieval_examples.append(thm['id'])

dataset = dict(
    theorems=theorems,
    definitions=definitions,
    others=others,
    retrieval_examples=retrieval_examples,
)

#### Split dataset

In [28]:
refs = theorems + definitions + others

# Lookups over all entries (rebuilt here from the final lists).
id2ref = {r['id']: r for r in refs}
label2id = {r['label']: r['id'] for r in refs}

# graph[a] lists every entry that refers to `a` (edge: referenced -> referrer).
graph = defaultdict(list)
pairs = []
cycles = []

for r1 in refs:
    r1id = r1['id']

    # References from the statement first, then (for theorems) from each proof.
    cited = list(r1['refs'])
    if r1['type'] == 'theorem':
        for proof in r1['proofs']:
            cited.extend(proof['refs'])

    for r2 in cited:
        r2id = label2id[r2]
        if r1id == r2id:
            # Ignore self-references.
            continue

        graph[r2id].append(r1id)
        pairs.append((r2id, r1id))

        # Looking up graph[r1id] on the defaultdict also inserts an (empty)
        # entry for r1id when it has none yet.
        if r2id in graph[r1id]:
            cycles.append(tuple(sorted((r2id, r1id))))

cycles = set(cycles)
print("%d 1-cycles" % (len(cycles)))

16 1-cycles


In [29]:
import networkx

G = networkx.DiGraph(graph)

# Classify every node in one pass over the graph.
#   leaf: has incoming edges but no outgoing ones
#   head: has outgoing edges but no incoming ones
#   non-leaf: everything that is not a leaf
leafs, nonleafs, heads = [], [], []
for node in G.nodes():
    n_in = G.in_degree(node)
    n_out = G.out_degree(node)

    if n_in != 0 and n_out == 0:
        leafs.append(node)
    else:
        nonleafs.append(node)

    if n_in == 0 and n_out > 0:
        heads.append(node)

summary = (len(G.nodes()), len(leafs), len(nonleafs), len(heads))
print("%d nodes\n%d leaf\n%d non-leaf\n\n%d heads" % summary)

13925 nodes
2107 leaf
11818 non-leaf

1958 heads


#### Define the train, valid, test splits

We define valid $\cup$ test as the retrieval examples that are leaves of the reference graph, i.e. theorems that no other item refers to. All remaining retrieval examples are used for training.

In [30]:
# Not used in this cell; the valid/test permutation cell re-creates it.
rand = np.random.RandomState(42)

splits = defaultdict(set)

# Set views make the membership tests below O(1) instead of scanning a list
# for every id; the resulting lists (content and order) are unchanged.
leaf_set = set(leafs)

# Evaluation theorems: retrieval examples that are leaves of the graph.
splits['eval_thms'] = [rid for rid in retrieval_examples if rid in leaf_set]
splits['train_thms'] = [rid for rid in retrieval_examples if rid not in leaf_set]

# Evaluation theorems are withheld from the training reference set.
splits['eval_refs'] = splits['eval_thms']
eval_ref_set = set(splits['eval_refs'])
splits['train_refs'] = [x for x in id2ref if x not in eval_ref_set]

for k in splits:
    splits[k] = list(splits[k])
    print(k, len(splits[k]))

eval_thms 1551
train_thms 9022
eval_refs 1551
train_refs 13583


#### Verify that evaluation theorems are not referred to in training

In [31]:
# Print every (evaluation theorem, training reference) pair connected by an
# edge from the former to the latter; no printed pairs means no leakage.
eval_ids = splits['eval_thms']
for eval_id in tqdm(eval_ids, total=len(eval_ids)):
    for train_id in splits['train_refs']:
        if not G.has_predecessor(train_id, eval_id):
            continue
        print(id2ref[eval_id]['title'], id2ref[train_id]['title'])

100%|██████████| 1551/1551 [00:04<00:00, 346.35it/s]


#### Randomly split evaluation into validation and test.

In [32]:
# Seeded shuffle of the evaluation theorems: the first half of the permutation
# is validation, the rest is test.
rand = np.random.RandomState(42)
eval_thms = splits['eval_thms']
perm = rand.permutation(len(eval_thms))

idx = len(eval_thms) // 2
val_idxs = perm[:idx]
tst_idxs = perm[idx:]

val_thms = [eval_thms[i] for i in val_idxs]
tst_thms = [eval_thms[i] for i in tst_idxs]

In [33]:
# Each example is a (theorem id, 0) pair. Training uses only the training
# references; valid and test use training plus evaluation references.
final_splits = {
    'train': {
        'ref_ids': splits['train_refs'],
        'examples': [(tid, 0) for tid in splits['train_thms']],
    },
    'valid': {
        'ref_ids': splits['train_refs'] + splits['eval_refs'],
        'examples': [(tid, 0) for tid in val_thms],
    },
    'test': {
        'ref_ids': splits['train_refs'] + splits['eval_refs'],
        'examples': [(tid, 0) for tid in tst_thms],
    },
}

In [34]:
# Report the size of every field of every split.
for split_name, fields in final_splits.items():
    print(split_name)
    for field_name, values in fields.items():
        print(field_name, len(values))
    print()

train
ref_ids 13583
examples 9022

valid
ref_ids 15134
examples 775

test
ref_ids 15134
examples 776



In [35]:
# Bundle the parsed dataset with its splits and write them to disk as JSON.
# (`json` is already imported in the notebook's import cell.)
output_json = './naturalproofs_stacks.json'
js = dict(dataset=dataset, splits=final_splits)

with open(output_json, 'w') as out_file:
    json.dump(js, out_file, indent=4)