In [None]:
from tqdm import tqdm
import pickle
import glob

bestfs_path = 'runs/end_to_end/original/2024_02_15/00_47_29/traces/0/'
bfs_path = 'runs/end_to_end/bfs/2024_04_05/18_45_30/traces/0/'

def load_traces(path):
    """Unpickle every entry directly inside the directory prefix ``path``.

    Args:
        path: directory prefix; '*' is appended verbatim, so it should end with '/'.

    Returns:
        A list with one unpickled object per matched file, in sorted filename
        order (glob's own order is arbitrary, which made re-runs non-reproducible).
    """
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted trace files.
    traces = []
    # No recursive=True: it only has an effect when the pattern contains '**'.
    for file in tqdm(sorted(glob.glob(path + '*'))):
        # `with` closes each file promptly instead of leaking one handle per trace.
        with open(file, "rb") as f:
            traces.append(pickle.load(f))

    return traces

In [None]:
# Load both runs: breadth-first search first, then best-first search.
bfs_traces, bestfs_traces = (load_traces(run_dir) for run_dir in (bfs_path, bestfs_path))

In [None]:
from experiments.end_to_end.proof_node import ErrorNode, Status
import math

# Accumulator filled by add_trace: one (node, subtree_visit_count) tuple per non-error node.
data = []
# NOTE(review): add_trace never reads or writes this; it builds its own local map.
# Kept only so any other code referencing the name keeps working.
visits = {}


def add_trace(trace, out=None):
    """Append each node of ``trace`` together with its subtree visit count.

    A node's subtree visit count is its own ``visit_count`` plus the
    ``visit_count`` of every node that lists it in ``ancestors``.

    Args:
        trace: search trace exposing ``.tree`` (the root node) and ``.nodes``
            (a mapping presumably keyed by goal). ``trace.nodes`` is NOT modified.
        out: list to append ``(node, subtree_visits)`` tuples to; defaults to the
            module-level ``data`` list.

    Traces whose root is an ``ErrorNode`` are skipped entirely, and ``ErrorNode``
    entries are never appended.
    """
    if out is None:
        out = data
    if isinstance(trace.tree, ErrorNode):
        return

    # Work on a shallow copy so the root can be included without mutating the
    # caller's trace object (which would leak state across notebook re-runs).
    nodes = dict(trace.nodes)
    nodes[trace.tree.goal] = trace.tree

    # Seed each entry with the node's own visits, then push every node's visits
    # up to each of its ancestors.
    # NOTE(review): assumes every ancestor is a key of `nodes` (KeyError otherwise)
    # and that `ancestors` lists each ancestor once — repeated entries would be
    # double counted. Confirm against proof_node.
    subtree_visits = {goal: node.visit_count for goal, node in nodes.items()}
    for node in nodes.values():
        for ancestor in node.ancestors:
            subtree_visits[ancestor] += node.visit_count

    for node in nodes.values():
        if not isinstance(node, ErrorNode):
            out.append((node, subtree_visits[node.goal]))


In [None]:
# Rebuild `data` from scratch so re-running this cell doesn't duplicate rows
# (add_trace appends to the module-level `data` list).
data.clear()
for t in bestfs_traces:
    add_trace(t)
len(data)

In [None]:
# takeaway:
# most nodes (89%) have visit counts under 64
# 11% of nodes have visit counts of 64 or more
# 3.6% of nodes are failed, 1.1% proven

In [None]:
# Plot the distribution of per-node subtree visit counts, up to MAX_VISITS visits.
# Counts above MAX_VISITS fall outside the last bin and are not shown.

import matplotlib.pyplot as plt

MAX_VISITS = 4096
BIN_WIDTH = 64

# range() excludes its stop value, so stop at MAX_VISITS + BIN_WIDTH to make the
# last bin edge land exactly on MAX_VISITS.
fig, ax = plt.subplots()
ax.hist([visit_count for _, visit_count in data],
        bins=range(0, MAX_VISITS + BIN_WIDTH, BIN_WIDTH))
ax.set(title="Distribution of node subtree visit counts",
       xlabel="subtree visit count",
       ylabel="number of nodes")
plt.show()


In [None]:
# Fraction (not a count) of collected nodes whose status is PROVED.
num_proven = sum(1 for entry in data if entry[0].status == Status.PROVED) / len(data)
num_proven

In [None]:
# Fraction (not a count) of collected nodes whose status is FAILED.
num_failed = sum(1 for entry in data if entry[0].status == Status.FAILED) / len(data)
num_failed

In [None]:
# Fraction (not a count) of collected nodes still OPEN.
num_open = sum(1 for entry in data if entry[0].status == Status.OPEN) / len(data)
num_open

In [None]:
# fraction of nodes whose subtree visit count is at least 64
num_over_64 = sum(1 for entry in data if entry[1] >= 64) / len(data)
num_over_64

In [None]:
# fraction of nodes whose subtree visit count is at least 128
num_over_128 = sum(1 for entry in data if entry[1] >= 128) / len(data)
num_over_128

In [None]:
# fraction of nodes whose subtree visit count is at least 256
num_over_256 = sum(1 for entry in data if entry[1] >= 256) / len(data)
num_over_256

In [None]:
def _proved_names(traces):
    """Return ``theorem.full_name`` for every trace whose ``proof`` is truthy, in input order."""
    return [t.theorem.full_name for t in traces if t.proof]


# Names of theorems each search variant found a proof for (lists: duplicates kept).
bfs_proved = _proved_names(bfs_traces)
best_fs_proved = _proved_names(bestfs_traces)

In [None]:
tuple(len(names) for names in (bfs_proved, best_fs_proved))  # (BFS proofs, best-first proofs)

In [None]:
# Compare the two searches by the distinct theorems they proved:
# proved by both, proved only by BFS, proved only by best-first search.
bfs_proved_set, best_fs_proved_set = set(bfs_proved), set(best_fs_proved)

intersection = bfs_proved_set & best_fs_proved_set
bfs_unique = bfs_proved_set - best_fs_proved_set
best_fs_unique = best_fs_proved_set - bfs_proved_set

len(intersection), len(bfs_unique), len(best_fs_unique)

In [None]:
# Directory prefix of the pickled training traces (must end with '/': '*' is appended to it).
train_path = "runs/train_traces/0/"

In [None]:
# Sorted so the loop below (and the `trace` variable it leaves behind) is reproducible;
# glob's own order is arbitrary. No recursive=True: it only matters for '**' patterns.
train_files = sorted(glob.glob(train_path + '*'))

In [None]:
from experiments.end_to_end.proof_node import Status

In [None]:
# takeaway: failed nodes follow a Zipf-like distribution across traces — a few proofs contain most of the failures
# failed nodes with no out_edges are usually root nodes that errored out
# failed nodes with out_edges are usually nodes that haven't been fully visited, implying an error in one of the expansions

In [None]:
# Classify every FAILED node across the training traces:
#   true_fails  - count of nodes where every outgoing edge has at least one FAILED
#                 destination AND the node used all of its expansions
#   weird_fails - FAILED nodes with outgoing edges that don't meet that criterion
#   no_edges    - FAILED nodes with no outgoing edges
# fail_dist records the number of FAILED nodes in each trace.
weird_fails = []
fail_dist = []
true_fails = 0
no_edges = []
for file in tqdm(train_files):
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted trace files.
    # `trace` stays at module level on purpose: a later debugging cell inspects the last one.
    with open(file, 'rb') as f:
        trace = pickle.load(f)

    fail_nodes = [node for node in trace.nodes.values() if node.status == Status.FAILED]
    for node in fail_nodes:
        if not node.out_edges:
            no_edges.append(node)
            continue

        every_edge_has_failed_dst = all(
            any(dst.status == Status.FAILED for dst in edge.dst) for edge in node.out_edges
        )
        if every_edge_has_failed_dst and node.visit_count >= node.max_expansions:
            true_fails += 1
        else:
            weird_fails.append(node)

    fail_dist.append(len(fail_nodes))


In [None]:
# (unexplained failures, genuine failures, failures without outgoing edges)
(len(weird_fails), true_fails, len(no_edges))

In [None]:
# Visit counts of the FAILED nodes that don't meet the "true fail" criterion.
[weird.visit_count for weird in weird_fails]

In [None]:
# For each outgoing edge of the first unexplained failure: (type of first destination, number of destinations).
[(type(edge.dst[0]), len(edge.dst)) for edge in weird_fails[0].out_edges]


In [None]:
# Incoming edges of the first unexplained failure.
first_weird_fail = weird_fails[0]
first_weird_fail.in_edges

In [None]:
# Display the first unexplained failure itself.
first_weird_fail = weird_fails[0]
first_weird_fail

In [None]:
# Histogram of the number of FAILED nodes per training trace. Traces with zero
# FAILED nodes are excluded because the first bin edge is 1.

import matplotlib.pyplot as plt

# Unit-width bins [k, k + 1) for k = 1..max. range() excludes its stop value, so
# stopping at max + 2 gives a final edge of max + 1 and keeps the largest count in
# its own bin rather than merging it with max - 1. `default=0` guards an empty list.
max_fails = max(fail_dist, default=0)
if max_fails < 1:
    print("No training trace has a FAILED node; nothing to plot.")
else:
    fig, ax = plt.subplots()
    ax.hist(fail_dist, bins=range(1, max_fails + 2))
    ax.set(title="FAILED nodes per training trace",
           xlabel="number of FAILED nodes in the trace",
           ylabel="number of traces")
    plt.show()


In [None]:
# Debug: `trace` is whatever the failure-classification loop loaded last.
# For its second FAILED node, show the type of each outgoing edge's first destination.
last_trace_fail_nodes = [node for node in trace.nodes.values() if node.status == Status.FAILED]
[type(edge.dst[0]) for edge in last_trace_fail_nodes[1].out_edges]
