### Dataset splits

Form dataset splits using the reference graph.

In [1]:
%pylab inline
import pandas as pd
import glob
import pickle
import os
import json
from pprint import pprint as pp
from collections import defaultdict
import torch
from tqdm import tqdm

Populating the interactive namespace from numpy and matplotlib


In [2]:
raw_ds = json.load(open('/path/to/dataset.json', 'r'))

### Form the reference graph

$G=(V,E)$
- $v\in V$: theorem, definition, other page
- $(u, v)\in E$: $u$ occurs in the statement or a proof of $v$

In [3]:
refs = raw_ds['dataset']['theorems'] + raw_ds['dataset']['definitions'] + raw_ds['dataset']['other']

graph = defaultdict(list)

id2ref = {}
ref2id = {}
for r in refs:
    ref2id[r['title']] = r['id']
    id2ref[r['id']] = r
    
title2proof = {}
for p in raw_ds['dataset']['proofs']:
    title2proof[p['title']] = p
    
pairs = []
cycles = []
for r1 in refs:
    
    # Make an edge for each reference in the _statement_
    for r2 in r1['refs']:
        
        r1id = r1['id']
        r2id = ref2id[r2]
        
        if r1id != r2id:
            graph[r2id].append(r1id)
            
            pairs.append((r2id, r1id))
            
            if r2id in graph[r1id]:
                cycles.append(tuple(sorted((r2id, r1id))))
                
    # Make an edge for each reference in the _proof_ (when available)
    if r1['type'] == 'theorem' and r1['has_proof']:
        for title in r1['proof_titles']:
            proof = title2proof[title]
            
            for r2 in proof['refs']:                
                r1id = r1['id']
                r2id = ref2id[r2]
                if r1id != r2id:
                    graph[r2id].append(r1id)
                    
                    pairs.append((r2id, r1id))

                    if r2id in graph[r1id]:
                        cycles.append(tuple(sorted((r2id, r1id))))

cycles = set(cycles)
print("%d 1-cycles" % (len(cycles)))

877 1-cycles


In [4]:
import networkx

G = networkx.DiGraph(graph)
leafs = [node for node in G.nodes() if G.in_degree(node) != 0 and G.out_degree(node)==0]
nonleafs = [node for node in G.nodes() if G.in_degree(node) == 0 or G.out_degree(node) != 0]
heads = [node for node in G.nodes() if G.in_degree(node) == 0 and G.out_degree(node) > 0]

print("%d nodes\n%d leaf\n%d non-leaf\n\n%d heads" % (
    len(G.nodes()),
    len(leafs),
    len(nonleafs),
    len(heads)
))

30681 nodes
12396 leaf
18285 non-leaf

1392 heads


#### BFS layers

Form BFS layers, count the number of nodes, example-worthy theorems (has proof(s) + contents), and 1-cycles.

In [6]:
print("total nodes %d\n" % len(G.nodes()))

incycle = set()
for a, b in cycles:
    incycle.add(a)
    incycle.add(b)

# theorems that correspond to examples (e.g. has a proof, contents)
tid2eid = {}
for item in raw_ds['dataset']['retrieval_examples']:
    tid2eid[item['theorem_id']] = item['example_id']
    
layers = defaultdict(set)
nleafs = []

seen = set()
for node in heads:
    layers[0].add(node)
    seen.add(node)
    
layer = 0
print('layer', 'nodes', 'thms', 'leaf_thms', sep='\t')

while len(layers[layer]) > 0:
    thms = [x for x in layers[layer] if x in tid2eid]
    leaf_thms = [x for x in layers[layer] if x in tid2eid
        and x in leafs
        and (x not in incycle)
    ]
    nleafs.append(len(leaf_thms))
    
    print(layer, len(layers[layer]), len(thms), len(leaf_thms), sep='\t')
    for node in layers[layer]:
        for child in G.successors(node):
            if child not in seen:
                layers[layer+1].add(child)
                seen.add(child)
    layer += 1

nleafs.append(0)
    
nleafs = np.array(nleafs)

total nodes 30681

layer	nodes	thms	leaf_thms
0	1392	0	0
1	10850	5777	2376
2	11713	5644	2581
3	5239	1898	788
4	1322	252	107
5	141	25	12
6	14	1	0


#### Define the train, valid, test splits

We define valid $\cup$ test as leaves, selected at each layer proportional to the number of leaves at the layer.

In [7]:
budget = 2200
leaf_frac = nleafs/nleafs.sum()

In [8]:
rand = np.random.RandomState(42)

splits = defaultdict(set)

for layer in range(len(layers)):
        
    # get number of eval leaves for this layer
    nleaf = int(budget*leaf_frac[layer])
    
    # randomly sample `nleaf` leaf theorems
    leaf_thms = [x for x in layers[layer] if x in tid2eid
        and x in leafs
        and (x not in incycle)
    ]
    perm = rand.permutation(len(leaf_thms))
    eval_thms = [leaf_thms[i] for i in perm[:nleaf]]
    
    # collect as evaluation theorems and references
    for x in eval_thms:
        splits['eval_thms'].add(x)
        splits['eval_refs'].add(x)
    
    # collect all other items as training data
    eval_thms_set = set(eval_thms)
    for x in layers[layer]:
        if x not in eval_thms_set:
            splits['train_refs'].add(x)
            if x in tid2eid:
                splits['train_thms'].add(x)
                
for k in splits:
    splits[k] = list(splits[k])
    print(k, len(splits[k]))

train_refs 28473
eval_thms 2198
eval_refs 2198
train_thms 11399


#### Verify that evaluation theorems are not referred in training

In [9]:
for x in tqdm(splits['eval_thms'], total=len(splits['eval_thms'])):
    for y in splits['train_refs']:
        if G.has_predecessor(y, x):
            print(id2ref[x]['title'], id2ref[y]['title'])

100%|██████████| 2198/2198 [00:42<00:00, 51.90it/s]


#### Randomly split evaluation into validation and test.

In [10]:
rand = np.random.RandomState(42)
perm = rand.permutation(len(splits['eval_thms']))

idx = len(splits['eval_thms'])//2
val_idxs = perm[:idx]

val_thms = [splits['eval_thms'][i] for i in perm[:idx]]
tst_thms = [splits['eval_thms'][i] for i in perm[idx:]]


#### Convert theorem ids to example ids

In [12]:
tid2eid = {}
for item in raw_ds['dataset']['retrieval_examples']:
    tid2eid[item['theorem_id']] = item['example_id']

In [13]:
final_splits = {
    'train': {},
    'valid': {},
    'test': {}
}

final_splits['train']['ref_ids'] = splits['train_refs']
final_splits['train']['example_ids'] = [tid2eid[t] for t in splits['train_thms']]

final_splits['valid']['ref_ids'] = splits['train_refs'] + splits['eval_refs']
final_splits['valid']['example_ids'] = [tid2eid[t] for t in val_thms]

final_splits['test']['ref_ids'] = splits['train_refs'] + splits['eval_refs']
final_splits['test']['example_ids'] = [tid2eid[t] for t in tst_thms]

In [14]:
for k in final_splits:
    print(k)
    for k2 in final_splits[k]:
        print(k2, len(final_splits[k][k2]))
    print()

train
ref_ids 28473
example_ids 11399

valid
ref_ids 30671
example_ids 1099

test
ref_ids 30671
example_ids 1099



In [15]:
raw_ds['splits'] = final_splits

import json
with open('/path/to/dataset.json', 'w') as f:
    json.dump(raw_ds, f)