In [None]:
import heapq
import os
import pickle
import sys
import time

import ray
import torch
from lean_dojo import (
    Pos,
    Dojo,
    Theorem,
    LeanGitRepo,
    ProofFinished,
    DojoInitError,
    DojoCrashError,
    DojoHardTimeoutError,
)
from lean_dojo.constants import LEAN3_DEPS_DIR, LEAN4_DEPS_DIR
from ray.util.actor_pool import ActorPool

from common import zip_strict
from generator.model import RetrievalAugmentedGenerator
from prover.new_search_tree import *

# Path to the generator checkpoint, relative to the notebook's working directory.
ckpt_path = 'gen.ckpt'

# Load the generator for inference. Prefer the GPU, but fall back to CPU so the
# notebook still loads on machines without CUDA instead of failing here.
# `freeze=True` is forwarded to the project loader; presumably it disables
# gradient updates — TODO confirm against RetrievalAugmentedGenerator.load.
tac_gen = RetrievalAugmentedGenerator.load(
    ckpt_path,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    freeze=True,
)


In [None]:
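# Disabled alternative: add dedicated critic tokens to the vocabulary instead of
# reusing the existing <extra_id_*> sentinel tokens.
# new_tok = ['<critic>', '<provable>', '<unprovable>']
# tac_gen.tokenizer.add_tokens(new_tok)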

In [None]:
# Number of entries in the tokenizer; the same value bounds the token-id range
# enumerated later when building the banned-token list.
len(tac_gen.tokenizer)

In [None]:
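# Disabled: only relevant if new tokens are added to the tokenizer (see the commented-out add_tokens call).
# tac_gen.generator.resize_token_embeddings(len(tac_gen.tokenizer), pad_to_multiple_of=1)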


In [None]:
# Display the generator's repr to inspect the underlying model.
tac_gen.generator

In [None]:
# Reuse sentinel/extra tokens that already exist in the vocabulary for the goal
# task: one marks the critic prompt, the other two are the possible labels.
critic_tok, provable_tok, unprovable_tok = (
    '<extra_id_0>',
    '<extra_id_1>',
    '<extra_id_2>',
)

In [None]:
def _single_token_id(tokenizer, token):
    """Return the vocabulary id of `token`, requiring it to encode to exactly one id.

    Args:
        tokenizer: tokenizer exposing `encode(text, add_special_tokens=...)`.
        token: the token string to look up.

    Returns:
        The single id the tokenizer produces for `token`.

    Raises:
        ValueError: if `token` encodes to zero or several ids. Taking the first
            piece in that case would silently select an unrelated token.
    """
    # Encode without special tokens so the result is only the token itself.
    token_ids = tokenizer.encode(token, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(
            f"{token!r} is not a single vocabulary token: encoded as {token_ids}"
        )
    return token_ids[0]


# Ids of the critic marker and of the two label tokens.
critic_id = _single_token_id(tac_gen.tokenizer, critic_tok)
provable_id = _single_token_id(tac_gen.tokenizer, provable_tok)
unprovable_id = _single_token_id(tac_gen.tokenizer, unprovable_tok)

In [None]:
# Example proof state (one hypothesis per line, goal last), prefixed with the
# critic marker token. Adjacent literals are concatenated into a single string.
state = critic_tok + (
    'C : Type u₁,\n'
    '_inst_1 : category C,\n'
    'D : Type u₂,\n'
    '_inst_2 : category D,\n'
    'F : C ⥤ D,\n'
    'W X Y Z : C,\n'
    'f : W ⟶ X,\n'
    'g : W ⟶ Y,\n'
    'h : X ⟶ Z,\n'
    'i : Y ⟶ Z,\n'
    '_inst_3 : reflects_colimit (span f g) F,\n'
    'e : f ≫ h = g ≫ i,\n'
    'H : is_pushout (F.map f) (F.map g) (F.map h) (F.map i)\n'
    '⊢ comm_sq f g h i'
)

In [None]:
# Encode the critic-prefixed state as PyTorch tensors, truncating at 1024 tokens.
encode_options = {
    "padding": "longest",
    "max_length": 1024,
    "truncation": True,
    "return_tensors": "pt",
}
tokenized_state = tac_gen.tokenizer(state, **encode_options)

In [None]:
# Inspect the tokenizer output for the critic-prefixed state.
tokenized_state

In [None]:
# Round-trip check: decode the first encoded sequence back to text.
tac_gen.tokenizer.decode(tokenized_state['input_ids'][0])

In [None]:
# Move the input ids to whichever device the generator's weights are on, rather
# than assuming the current CUDA device, so inputs and model always match.
ids = tokenized_state.input_ids.to(tac_gen.generator.device)

In [None]:
# Constrain generation to the two label tokens by banning every other token id.
# Each banned id is wrapped in its own single-element list.
allowed_ids = {provable_id, unprovable_id}
bad_ids = [
    [token_id]
    for token_id in range(len(tac_gen.tokenizer))
    if token_id not in allowed_ids
]

In [None]:
# Show the banned-token list: one single-id entry per token other than the two labels.
# NOTE(review): this displays the entire list; `len(bad_ids)` may be a more readable check.
bad_ids

In [None]:
# Generate at most two new tokens with all non-label tokens banned, and ask for
# a dict-style result that also carries the per-step scores.
generation_options = dict(
    max_new_tokens=2,
    bad_words_ids=bad_ids,
    return_dict_in_generate=True,
    output_scores=True,
)
tactic_ids = tac_gen.generator.generate(ids, **generation_options)

In [None]:
# Per-step scores, available because `output_scores=True` was passed to `generate`.
tactic_ids.scores

In [None]:
# Scores of the first generated step for the first (only) sequence in the batch.
first_step_scores = tactic_ids.scores[0][0]

# Generation scores are pre-softmax values, so exponentiating a single entry does
# not give a probability. Normalise over the two label tokens instead so that
# provable_score + unprovable_score == 1.
label_probs = torch.softmax(
    first_step_scores[[provable_id, unprovable_id]], dim=-1
)
provable_score = label_probs[0]
unprovable_score = label_probs[1]

In [None]:
# Token ids of the first generated sequence.
tactic_ids.sequences[0]

In [None]:
# Decode the generated sequences for inspection. Special tokens are kept because
# the label tokens are `<extra_id_*>` sentinels; skipping special tokens would
# presumably strip the label and leave an empty string — TODO confirm.
tac_gen.tokenizer.batch_decode(tactic_ids.sequences, skip_special_tokens=False)



In [None]:
# Rebinds `provable_score` (already assigned in an earlier cell) to the raw
# first-step score of the provable token for the first sequence in the batch.
provable_score = tactic_ids.scores[0][0][provable_id]

In [None]:
# Rebinds `unprovable_score` (already assigned in an earlier cell) to the raw
# first-step score of the unprovable token for the first sequence in the batch.
unprovable_score = tactic_ids.scores[0][0][unprovable_id]

In [None]:
from goal_model.datamodule import GoalDataModule

# Sentinel tokens for the critic prompt and its two labels (same values as used
# for the generation experiment).
critic_tok, provable_tok, unprovable_tok = (
    '<extra_id_0>',
    '<extra_id_1>',
    '<extra_id_2>',
)

# Keep every data-module setting in one mapping so it is easy to scan and tweak.
datamodule_options = dict(
    data_path='goal_data.pk',
    max_seq_len=1024,
    batch_size=8,
    model_name="kaiyuy/leandojo-lean3-tacgen-byt5-small",
    critic_tok=critic_tok,
    provable_tok=provable_tok,
    unprovable_tok=unprovable_tok,
    num_workers=4,
    val_data_path=None,
    eval_batch_size=8,
    visit_threshold=256,
)
module = GoalDataModule(**datamodule_options)

# Run the 'fit' stage setup, then print one training batch as a smoke test.
module.setup('fit')
loader = module.train_dataloader()
print(next(iter(loader)))


In [None]:
# Count the training records whose 'proved' field equals 1.
sum(1 for record in module.ds_train.data if record['proved'] == 1)

In [None]:
# NOTE(review): the path follows TensorBoard's `events.out.tfevents.*` naming.
# Such files are presumably not pickle streams, so `pickle.load` is expected to
# raise here — confirm, and read the log with a TensorBoard event reader instead.
# NOTE(review): `pickle` is already imported in the first cell; this import is redundant.
import pickle
# Caution: `pickle.load` can execute arbitrary code; only use it on trusted files.
with open('lightning_logs/version_18/events.out.tfevents.1697004955.pc.48076.0', 'rb') as f:
    res = pickle.load(f)