In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import List
from transformers import PreTrainedTokenizer, PreTrainedModel


def test_generation(state: str, premises: List[str], tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
    # inpus is just a concatenation of premises and proof state
    input = "\n\n".join(premises + [state])
    print("------ INPUT ------\n", input)

    # tokenize
    tokenized_input = tokenizer(input, return_tensors="pt", max_length=2300, truncation=True)

    # Generate a single tactic.
    tactic_ids = model.generate(tokenized_input.input_ids, max_length=1024)
    tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
    print("\n------ OUTPUT ------")
    print(tactic, end="\n\n")

    # Generate multiple tactics via beam search.
    tactic_candidates_ids = model.generate(
        tokenized_input.input_ids,
        max_length=1024,
        num_beams=4,
        length_penalty=0.0,
        do_sample=False,
        num_return_sequences=4,
        early_stopping=False,
    )
    tactic_candidates = tokenizer.batch_decode(tactic_candidates_ids, skip_special_tokens=True)
    for tac in tactic_candidates:
        print(tac)

In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the tokenizer and the T5-style generator from the same checkpoint name.
model_name = "kaiyuy/leandojo-lean3-retriever-tacgen-byt5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = T5ForConditionalGeneration.from_pretrained(model_name)

# Example proof state used as the generation query.
state = "n : ℕ\n⊢ gcd n n = n"
# Premises are written out by hand here (no retriever is called in this cell);
# each entry is a declaration whose name is wrapped in <a>...</a>.
retrieved_premises = [
    "def <a>nat.gcd</a> : nat → nat → nat\n| 0        y := y\n| (succ x) y := have y % succ x < succ x, from mod_lt _ $ succ_pos _,\n                gcd (y % succ x) (succ x)",
    "@[simp] theorem <a>nat.mod_self</a> (n : nat) : n % n = 0",
]

In [None]:
# Print the prompt, one generated tactic, and the beam-search candidates for the example state and premises.
test_generation(state, retrieved_premises, tokenizer, generator)

## Run eval step for baseline model

In [1]:
# Same autoreload setup as at the top of the notebook.
# NOTE(review): the execution count restarts at 1 here, so this section was presumably run in a fresh kernel — confirm.
%load_ext autoreload
%autoreload 2

In [7]:
from reprover.retrieval.datamodule import RetrievalDataModule

# The same checkpoint name is used below for both the context and the premise tokenizer.
# NOTE(review): the checkpoint name says "lean3" while the data directory is leandojo_benchmark_4 — confirm this pairing is intended.
model_name = "kaiyuy/leandojo-lean3-retriever-tacgen-byt5-small"
corpus_path = "../data/leandojo_benchmark_4/corpus.jsonl"

# Build the retrieval data module over the novel_premises data directory,
# with batch size 1 for both training and evaluation and num_workers=0.
data_module = RetrievalDataModule(
    "../data/leandojo_benchmark_4/novel_premises",
    corpus_path,
    num_negatives=2,
    num_in_file_negatives=1,
    context_tokenizer_name=model_name,
    premise_tokenizer_name=model_name,
    batch_size=1,
    eval_batch_size=1,
    max_seq_len=1024,
    num_workers=0,
)
# NOTE(review): setup() presumably loads the dataset splits (the captured log shows train.json and val.json being read) — confirm.
data_module.setup()

WARINING: restricting Corpus to 100 lines


[32m2024-03-09 21:00:54.083[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m52[0m - [1mLoading data from ../data/leandojo_benchmark_4/novel_premises/train.json[0m
100%|██████████| 98514/98514 [00:10<00:00, 8985.66it/s] 
[32m2024-03-09 21:01:13.928[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m92[0m - [1mLoaded 279112 examples.[0m
[32m2024-03-09 21:01:13.934[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m52[0m - [1mLoading data from ../data/leandojo_benchmark_4/novel_premises/val.json[0m
100%|██████████| 2000/2000 [00:00<00:00, 3274.15it/s]
[32m2024-03-09 21:01:14.630[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m92[0m - [1mLoaded 9705 examples.[0m
[32m2024-03-09 21:01:14.631[0m | [1mINFO    [0m | [36mreprover.retrieval.datamodule[0m:[36m_load_data[0m:[36m52[0m - [1mLoading data from ../data/le

In [8]:
import torch
from reprover.retrieval.model import PremiseRetriever

# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so this cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE(review): this variable is named `generator` but holds a PremiseRetriever.
# The name is kept because the following cell calls generator.load_corpus and
# generator.retrieve.
generator = PremiseRetriever(
    model_name=model_name,
    lr=3e-4,
    warmup_steps=100,
    max_seq_len=1024,
    num_retrieved=100,
).to(device)

In [9]:
# Load the premise corpus from corpus_path into the retriever.
generator.load_corpus(corpus_path)

# Take only the first validation batch. next(iter(...)) raises StopIteration on
# an empty loader, whereas the previous `for ... break` form would silently
# leave a stale `batch` from an earlier run in scope.
batch = next(iter(data_module.val_dataloader()))

# Retrieve with k=1 for every context in the batch. This stays the last
# expression of the cell so the notebook displays the returned premises and scores.
generator.retrieve(
    state=[ctx.serialize() for ctx in batch["context"]],
    file_name=batch["file_path"],
    theorem_full_name=batch["full_name"],
    theorem_pos=[ctx.theorem_pos for ctx in batch["context"]],
    k=1,
    reindex_batch_size=64,
)

WARINING: restricting Corpus to 100 lines


[32m2024-03-09 21:02:06.914[0m | [1mINFO    [0m | [36mreprover.retrieval.model[0m:[36mreindex_corpus[0m:[36m91[0m - [1mRe-indexing the retrieval corpus[0m
100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


([[Premise(path='lake-packages/lean4/src/lean/Init/Prelude.lean', full_name='HSub', code='class HSub (α : Type u) (β : Type v) (γ : outParam (Type w)) where\n  \n  hSub : α → β → γ')]],
 [[0.6442855596542358]])

In [1]:
from colbert.data import Queries
from colbert.infra import ColBERTConfig
from colbert import Searcher, Indexer


# Indexing configuration: nbits=2 with the experiment root set to ../experiments/test.
# NOTE(review): the captured log writes the index under
# .../notebooks/experiments/default/indexes/test_index rather than under
# ../experiments/test — confirm that `root` is actually honored here.
config = ColBERTConfig(
    nbits=2,
    root="../experiments/test",
)
# Build the "test_index" index over colbert_collection_100.tsv using the
# colbertv2.0 checkpoint. With overwrite=True an existing index of the same
# name is replaced (the captured log shows existing files being deleted first).
indexer = Indexer(checkpoint="../checkpoints/colbertv2.0", config=config)
indexer.index(
    name="test_index",
    collection="../data/leandojo_benchmark_4/novel_premises/colbert_collection_100.tsv",
    overwrite=True,
)



[Mar 09, 21:33:44] #> Note: Output directory /Users/ykapushev/Work/ReProver/notebooks/experiments/default/indexes/test_index already exists


[Mar 09, 21:33:44] #> Will delete 1 files already at /Users/ykapushev/Work/ReProver/notebooks/experiments/default/indexes/test_index in 20 seconds...
#> Starting...
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 2,
    "kmeans_niters": 20,
    "resume": false,
    "similarity": "cosine",
    "bsize": 64,
    "accumsteps": 1,
    "lr": 1e-5,
    "maxsteps": 400000,
    "save_every": null,
    "warmup": 20000,
    "warmup_bert": null,
    "relu": false,
    "nway": 64,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": null,
   

100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
OMP: Error #15: Initializing libomp.dylib, but found libomp.dylib already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://openmp.llvm.org/


[Mar 09, 21:34:07] [0] 		 avg_doclen_est = 27.53535270690918 	 len(local_sample) = 99
[Mar 09, 21:34:07] [0] 		 Creating 512 partitions.
[Mar 09, 21:34:07] [0] 		 *Estimated* 2,725 embeddings.
[Mar 09, 21:34:07] [0] 		 #> Saving the indexing plan to /Users/ykapushev/Work/ReProver/notebooks/experiments/default/indexes/test_index/plan.json ..
Clustering 2590 points in 128D to 512 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.00 s


In [None]:
# Rebind `config` for searching; only `root` is set here (no nbits, unlike the indexing cell).
config = ColBERTConfig(
    root="../experiments/test",
)
# Open the "test_index" index by name and load the queries from colbert_queries.json.
searcher = Searcher(index="test_index", config=config)
queries = Queries("../data/leandojo_benchmark_4/novel_premises/colbert_queries.json")
# Search every query with k=100 and save the resulting ranking as test_ranking.tsv.
ranking = searcher.search_all(queries, k=100)
ranking.save("test_ranking.tsv")