In [None]:
import pickle
import random

import glob
import math

import torch
from tqdm import tqdm


def get_traces(path):
    """Load every pickled trace whose file path matches a glob pattern.

    Args:
        path: Glob pattern (as accepted by ``glob.glob``) selecting the trace files.

    Returns:
        list: The unpickled objects, one per matching file, ordered by sorted file path.
    """
    # glob.glob returns matches in arbitrary order; sorting keeps positional
    # lookups into the returned list stable from one run to the next.
    files = sorted(glob.glob(path))

    traces = []
    for file in tqdm(files):
        # NOTE: pickle.load can execute arbitrary code, so only point this at trusted trace files.
        with open(file, "rb") as f:
            traces.append(pickle.load(f))
    return traces


# Load the traces of the run under analysis. The commented-out lines are
# alternative trace directories that can be swapped in (or appended) instead.
# traces = get_traces('../experiments/runs/eval_loop/leandojo_eval_2023_11_08/17_37_14/traces/*')
# traces = get_traces('../experiments/runs/eval_loop/leandojo_eval_2023_11_10/12_32_48/traces/*')
traces = get_traces('../experiments/runs/eval_loop/goal_model_2023_11_17/18_11_05/traces/*')
# traces.extend(get_traces('../traces_2023-10-31_17:28/*'))

In [None]:
# Spot-check: the out-edge at index 10 of the tree node of the trace at index 1.
traces[1].tree.out_edges[10]


In [None]:
# Number of traces loaded.
len(traces)

In [None]:
from refactor.proof_node import Status, ErrorNode

# Fraction of loaded traces whose tree status is FAILED.
len([t for t in traces if t.tree.status == Status.FAILED])/ len(traces)

In [None]:
# Fraction of loaded traces whose tree status is PROVED.
len([t for t in traces if t.tree.status == Status.PROVED])/ len(traces)

In [None]:
# Traces with FAILED status whose tree is not an ErrorNode.
failed = [t for t in traces if t.tree.status == Status.FAILED and not isinstance(t.tree, ErrorNode)]

In [None]:
# Spot-check: the node mapping of the failed trace at index 2.
failed[2].nodes

In [None]:
# Inspect one failed trace: print its expansion count, then show
# (visit_count, is_explored) for every node it recorded.
trace = failed[28]
print (trace.num_expansions)
[(node.visit_count, node.is_explored) for node in trace.nodes.values()]

In [None]:
# Spot-check: the fourth-from-last out-edge of the tree node of the failed trace at index 28.
failed[28].tree.out_edges[-4]


In [None]:
# Spot-check: display the failed trace at index 4.
failed[4]

In [None]:
from experiments.reprover.render_proof import render_full_trace, render_nx

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Render the first loaded trace with the project's render_full_trace helper.
render_full_trace(traces[0])

In [None]:
import torch

In [None]:
# Query CUDA memory information for device cuda:0.
torch.cuda.mem_get_info('cuda:0')

In [None]:
from refactor.proof_node import InternalNode, ErrorNode, ProofFinishedNode

In [None]:
from pymongo import MongoClient


In [None]:
# Connect to MongoDB with the client's default connection settings and
# select the 'goal_data' collection of the 'lean_dojo' database.
client = MongoClient()
db = client['lean_dojo']
# NOTE: `collection` is a generic name that a later cell rebinds to db['edge_data'].
collection = db['goal_data']

In [None]:
import torch
def add_goal_data(node, visits):
    """Build the goal-level record for a single node.

    Args:
        node: Object exposing ``goal``, ``distance_to_proof``, ``out_edges`` and ``up_score``.
        visits: Mapping from goal to the visit count to record for that goal.

    Returns:
        dict: Record with keys 'goal', 'distance_to_proof', 'visits',
        'local_visits' and 'score'.
    """
    # todo add up_score as new estimate? only after certain visit_threshold
    record = {
        'goal': node.goal,
        'distance_to_proof': node.distance_to_proof,
        'visits': visits[node.goal],
    }

    # Number of out-edges; a missing (falsy) out_edges counts as zero.
    out_edges = node.out_edges
    record['local_visits'] = len(out_edges) if out_edges else 0

    # Tensor scores are unwrapped to plain Python numbers; anything else is stored as-is.
    score = node.up_score
    if isinstance(score, torch.Tensor):
        score = score.item()
    record['score'] = score

    return record

# Write through an explicitly named handle: the generic `collection` variable is
# rebound to db['edge_data'] by a later cell, so re-running this cell out of order
# would otherwise store goal records in the wrong collection.
goal_collection = db['goal_data']

for trace in tqdm(traces):
    # Traces whose tree is an ErrorNode are skipped.
    if isinstance(trace.tree, ErrorNode):
        continue
    nodes = trace.nodes
    # Register the tree node under its own goal so it is processed like every other node.
    nodes[trace.tree.goal] = trace.tree

    # Start from each node's own visit count, keyed the same way as `nodes` ...
    updated_visit_count = {goal: node.visit_count for goal, node in nodes.items()}

    # ... then add every node's visits to each of its ancestors.
    for node in nodes.values():
        for a in node.ancestors:
            updated_visit_count[a] += node.visit_count

    # One batched write per trace instead of one round trip per node.
    goal_data = [add_goal_data(node, updated_visit_count) for node in nodes.values()]
    if goal_data:  # insert_many rejects an empty document list
        goal_collection.insert_many(goal_data)
        # goal_step_data.extend(goal_data)


In [None]:
# Rebind the generic `collection` handle to the 'edge_data' collection;
# cells below that use `collection` write edge records.
collection = db['edge_data']

In [None]:
def get_edge_data(trace):
    """Flatten a trace's tactic steps into one record per edge.

    Args:
        trace: Object exposing ``theorem`` and ``tac_trace`` (an iterable of edges).

    Returns:
        list[dict]: One record per edge, in ``tac_trace`` order. 'outcome' is
        ['Error'] or ['Proven'] when the edge has a single ErrorNode /
        ProofFinishedNode child, otherwise the list of child goals.
    """
    def describe_outcome(children):
        # A lone ErrorNode / ProofFinishedNode child gets a sentinel label.
        if len(children) == 1:
            only_child = children[0]
            if isinstance(only_child, ErrorNode):
                # todo could record error message for e.g. self-correcting proof approach
                return ['Error']
            if isinstance(only_child, ProofFinishedNode):
                return ['Proven']
        # Otherwise record the goal of every child of the edge.
        return [child.goal for child in children]

    records = []
    for step, edge in enumerate(trace.tac_trace):
        record = {
            'iteration': 0,
            'step': step,
            'top_goal': trace.theorem,
            'goal': edge.src.goal,
            'tactic': edge.tactic,
            'goal_prob': edge.src.cumulative_logprob,
            'tac_prob': edge.logprob,
            'distance_to_proof': edge.distance_to_proof(),
            'visits': edge.visit_count(),
            'time': edge.time,
        }
        record['outcome'] = describe_outcome(edge.dst)
        records.append(record)
    return records

In [None]:
# edge_data = []

# Store one record per tactic step of every trace whose tree is not an ErrorNode.
for trace in tqdm(traces):
    if isinstance(trace.tree, ErrorNode):
        continue
    edge_records = get_edge_data(trace)
    # insert_many raises on an empty document list, so traces without any
    # recorded tactic step are skipped instead of aborting the whole loop.
    if edge_records:
        collection.insert_many(edge_records)
    # edge_data.extend(get_edge_data(trace))



In [None]:
# Spot-check: number of tactic steps recorded in the trace at index 6.
len(traces[6].tac_trace)

In [None]:
# Spot-check: expansion count of the first loaded trace.
traces[0].num_expansions

In [None]:
def add_rand_idx(collection):
    """Give every document lacking one a random 'rand_idx' field, then index it.

    Only documents without a 'rand_idx' field are touched, so values assigned
    by earlier calls are kept.

    Args:
        collection: pymongo-style collection exposing ``update_many`` and ``create_index``.

    Returns:
        None
    """
    # Use the server's native $rand expression (a random float in [0, 1) per
    # document, MongoDB 4.4.2+) instead of a server-side JavaScript $function,
    # which can be disabled on the server and is much slower.
    collection.update_many(
        {'rand_idx': {'$exists': False}},
        [{'$set': {'rand_idx': {'$rand': {}}}}],
    )

    collection.create_index('rand_idx')
    return


In [None]:
# Collection that receives the winner/loser tactic pairs written by rank_edges.
rank_collection = db['tac_ranks']

In [None]:
def transform_goal(goal_datum, max_len=10, visit_threshold=2048):
    """Convert a goal record into a proof-length target, or drop it.

    Targets:
        * distance below ``max_len``: ``(max_len + 1) - distance``
        * any other finite distance: 1
        * non-finite distance with at least ``visit_threshold`` visits: 0

    Args:
        goal_datum: Mapping with 'goal', 'distance_to_proof' and 'visits'.
        max_len: Distances below this value get a graded target.
        visit_threshold: Minimum visits for an unproven goal to be kept.

    Returns:
        dict | None: ``{'goal': ..., 'target': ...}``, or None when the goal is
        unproven and has fewer than ``visit_threshold`` visits.
    """
    distance = goal_datum['distance_to_proof']

    if distance < max_len:
        target = (max_len + 1) - distance
    elif distance < math.inf:
        target = 1
    elif goal_datum['visits'] >= visit_threshold:
        target = 0
    else:
        # Unproven and not visited enough to be used as a negative example.
        return None

    return {'goal': goal_datum['goal'], 'target': target}


In [None]:
def transform_goal_proven(goal_datum, visit_threshold=2048):
    """Convert a goal record into a binary proven / not-proven target, or drop it.

    Args:
        goal_datum: Mapping with 'goal', 'distance_to_proof' and 'visits'.
        visit_threshold: Minimum visits for an unproven goal to be kept.

    Returns:
        dict | None: ``{'goal': ..., 'target': 1}`` for a finite distance to
        proof, ``{'goal': ..., 'target': 0}`` for an unproven goal with at least
        ``visit_threshold`` visits, and None otherwise.
    """
    if goal_datum['distance_to_proof'] < math.inf:
        target = 1
    elif goal_datum['visits'] >= visit_threshold:
        target = 0
    else:
        # Unproven and not visited enough to be used as a negative example.
        return None

    return {'goal': goal_datum['goal'], 'target': target}



In [None]:
# Source collection of goal records.
goal_collection = db['goal_data']

# Destination collection for the proof-length task.
goal_len_collection = db['goal_len_task']

# Convert every stored goal record with transform_goal and keep the ones it does not drop.
# NOTE: re-running this cell inserts the same task records again.
for datum in tqdm(goal_collection.find()):
    len_data = transform_goal(datum)
    if len_data:
        goal_len_collection.insert_one(len_data)


In [None]:
# Destination collection for the binary proven / not-proven task.
goal_proven_task = db['goal_proven_task']

# Convert every stored goal record with transform_goal_proven and keep the ones it does not drop.
# NOTE: re-running this cell inserts the same task records again.
for datum in tqdm(goal_collection.find()):
    # `len_data` holds a proven-task record here; the name is carried over from the length-task cell.
    len_data = transform_goal_proven(datum)
    if len_data:

        goal_proven_task.insert_one(len_data)


In [None]:
# Attach a random 'rand_idx' field (and an index on it) to both task collections.
add_rand_idx(goal_len_collection)
add_rand_idx(goal_proven_task)


In [None]:

# create pairs of winners/losers based on edges from a given goal, and maintain tac probs for each
# e.g. edges = find({'goal': 'goal'}).edges
def rank_edges(goal, edges):
    """Derive winner/loser tactic pairs from the edges of one goal and store them.

    Pairs are produced in four groups, in this order:
        * 'valid_rank': every non-error edge beats every error edge.
        * 'proven_rank': among non-error edges, every edge with a finite
          distance to proof beats every edge with an infinite one.
        * 'time_len_rank': edges with a finite distance to proof are sorted by
          (distance_to_proof, time); each beats all edges after it.
        * 'same_outcome': among non-error edges with infinite distance, two
          edges leading to the same set of outcomes are ranked by time.

    Args:
        goal: Goal the edges start from; copied into every pair.
        edges: Edges of that goal; each exposes ``dst``, ``tactic``, ``logprob``,
            ``time`` and ``distance_to_proof()``.

    Returns:
        None. When at least one pair exists, all pairs are written to the
        module-level ``rank_collection`` with a single ``insert_many``.
    """
    def make_pair(winner, loser, rank_type):
        # One ranking record; key order matches what is stored in the collection.
        return {'goal': goal, 'winner': winner.tactic, 'winner_prob': winner.logprob,
                'loser': loser.tactic, 'loser_prob': loser.logprob, 'type': rank_type}

    def is_error(edge):
        return isinstance(edge.dst[0], ErrorNode)

    def outcome_of(edge):
        # Child goals for internal nodes, otherwise a sentinel label.
        first_child = edge.dst[0]
        if isinstance(first_child, InternalNode):
            return [child.goal for child in edge.dst]
        if isinstance(first_child, ErrorNode):
            return ['Error']
        return ['Proven']

    valid_edges = [candidate for candidate in edges if not is_error(candidate)]
    invalid_edges = [candidate for candidate in edges if is_error(candidate)]

    # rank all valid_edges above all invalid_edges
    pairs = [make_pair(good, bad, 'valid_rank') for good in valid_edges for bad in invalid_edges]

    # from valid_edges, rank proven goals above non_proven valid goals
    proven_edges = [candidate for candidate in valid_edges if candidate.distance_to_proof() < math.inf]
    success_non_proven_edges = [candidate for candidate in valid_edges
                                if candidate.distance_to_proof() == math.inf]
    pairs += [make_pair(proven, unproven, 'proven_rank')
              for proven in proven_edges for unproven in success_non_proven_edges]

    # from proven edges, rank based on distance_to_proof, then execution time
    ranked_proofs = sorted(proven_edges, key=lambda x: (x.distance_to_proof(), x.time))
    for position, better in enumerate(ranked_proofs):
        for worse in ranked_proofs[position + 1:]:
            pairs.append(make_pair(better, worse, 'time_len_rank'))

    # among successful without proof, rank those that lead to the same outcome based on time only
    for position, first in enumerate(success_non_proven_edges):
        for second in success_non_proven_edges[position + 1:]:
            if set(outcome_of(first)) != set(outcome_of(second)):
                continue
            # The strictly faster edge wins; on equal time the later edge is the winner.
            if first.time < second.time:
                pairs.append(make_pair(first, second, 'same_outcome'))
            else:
                pairs.append(make_pair(second, first, 'same_outcome'))

    if pairs:
        rank_collection.insert_many(pairs)
    return



In [None]:
# all_goals = [edge['goal'] for edge in tqdm(collection.find())]

In [None]:
# all_goals = set(all_goals)

In [None]:
from tqdm import tqdm

# data = []
# for goal in tqdm(all_goals):
# Rank the out-edges of every node of every usable trace.
for trace in tqdm(traces):
    # test_edges = [edge for edge in collection.find({'goal': goal})]
    # goal, winners, losers = rank_edges(goal=goal, edges=test_edges)
    # Skip traces whose tree is an ErrorNode.
    if isinstance(trace.tree, ErrorNode):
        continue
    # `nodes` aliases trace.nodes, so the assignment below also updates the trace itself.
    nodes = trace.nodes
    # Add the tree node under its own goal so its out-edges are ranked as well.
    nodes[trace.tree.goal] = trace.tree

    for node in nodes.values():
        # Only nodes with at least one out-edge are ranked; rank_edges stores its pairs itself.
        if node.out_edges:
            rank_edges(goal=node.goal, edges=node.out_edges)
    # data.append((goal, winners, losers))

In [None]:
# Attach a random 'rand_idx' field (and an index on it) to the ranking collection.
add_rand_idx(rank_collection)

In [None]:
# Handle on the goal-level records.
goal_collection = db['goal_data']

In [None]:
# Attach a random 'rand_idx' field (and an index on it) to the goal records.
add_rand_idx(goal_collection)

In [None]:
# NOTE(review): `data` is never assigned in the visible notebook (the lines that
# built it are commented out in the ranking-loop cell, and rank_edges returns
# nothing), so this cell raises NameError on a fresh kernel. Stale cell - confirm
# whether it should be removed or rewritten to read from rank_collection.
goal, winners, losers = data[8]


len(winners)

In [None]:
# Print the winner and loser at one index.
# NOTE(review): relies on `winners` / `losers` unpacked from `data`, which is not defined in the visible notebook.
i = 419
print (winners[i])
print (losers[i])

In [None]:
# Print the winner and display the loser at the sixth-from-last index.
# NOTE(review): relies on `winners` / `losers` unpacked from `data`, which is not defined in the visible notebook.
i = -6
print (winners[i])
losers[i]

In [None]:
# todo: add a method to reconstruct the search tree from the edge data above
# Run the normal search process, but replace run_tac with outcome -> edge, get_goals with goal, and get_tactics with tactic
# Useful for reward-based goal models

In [None]:
# todo: how to merge different attempts at the same proof?
# For goal data: if the new proof length is lower, take the new data point. If the goal failed and the new visit count is higher, replace it with that as well
# I.e. on every new attempt, add all new goals, and also update existing goals using the criteria above

# For edge data...
# Assume all valid/invalid edges are still valid/invalid; then those rankings are fine
# Rankings from proven/success could change if a success turns out to be a proof
# Rankings within a proof could also change if a shorter proof is found via the children
# The effect seems small/unlikely to make much of a difference. Worst case is that a longer proof is ranked better than a shorter/slower one

# Don't just keep the best trace, since we may discard useful old goals
# The best trace is the one with the shortest proof...


# todo check logits of forward match those from generation

# todo train scripts for eval models

# todo htps

# todo add BFS, bestfs


In [None]:
# run in environment