In [None]:
import torch
import random
import numpy as np

from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)

from datasets import Dataset
from typing import List , Optional

: 

In [None]:
# Seed Python's, NumPy's and PyTorch's random number generators with one shared value.
# Choice of seed value: https://arxiv.org/abs/2109.08203
SEED = 3407
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Install the extra packages this notebook needs into the running kernel's environment.
# NOTE(review): versions are unpinned — pin them once a known-good combination is confirmed.
# NOTE(review): confirm that the PyPI package named `lean` is really the intended dependency.
%pip install lean
%pip install lean_dojo

In [None]:
from lean_dojo import *

In [None]:
# Repository handle for mathlib4, which is traced in the next step.
# NOTE(review): the commit is an empty string here — confirm which revision should be used.
repo = LeanGitRepo(url='https://github.com/leanprover-community/mathlib4', commit='')

In [None]:
# Trace the repository with LeanDojo; the result is queried below for its traced theorems.
# NOTE(review): presumably a long-running step for a repository of this size — consider caching.
traced_repo = trace(repo)

In [None]:
# Collect every traced theorem and report how many there are.
theorems = traced_repo.get_traced_theorems()
theorem_count = len(theorems)
print(f'{theorem_count} theorems traced')

In [None]:
# Flatten the traced tactics of all theorems into one list of records:
# 'state' is the state before the tactic ran, 'tactic' is the tactic itself.
static_tactic_pairs = [
    {'state': traced_tactic.state_before, 'tactic': traced_tactic.tactic}
    for traced_theorem in tqdm(theorems)
    for traced_tactic in traced_theorem.get_traced_tactics()
]


In [None]:
# Show the first extracted pair: its state first, then its tactic.
st = static_tactic_pairs[0]
for field in ('state', 'tactic'):
  print(st[field])

In [None]:
# Load the pretrained ByT5-small checkpoint together with its matching tokenizer.
BASE_CHECKPOINT = 'google/byt5-small'
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_CHECKPOINT)

In [None]:
# Build a Hugging Face dataset from `static_tactic_pairs` (the list filled by the
# extraction loop), shuffle it, and keep at most 10,000 examples. The size is
# capped with min() so the cell also works when fewer pairs were extracted.
shuffled_pairs = Dataset.from_list(static_tactic_pairs).shuffle()
dataset = shuffled_pairs.select(range(min(10000, len(shuffled_pairs))))

def tokenize(examples):
  """Tokenize a batch of state/tactic examples for seq2seq training.

  Args:
    examples: Batch mapping with a 'state' entry (model inputs) and a
      'tactic' entry (targets), as supplied by `Dataset.map(..., batched=True)`.

  Returns:
    The tokenizer output for the states, with the token ids of the tactics
    stored under the 'labels' key.
  """
  model_inputs = tokenizer(examples['state'] , max_length=2048 , truncation=True)
  # Targets are passed through the `text_target` keyword so the tokenizer
  # encodes them as labels.
  labels = tokenizer(text_target=examples['tactic'] , max_length=2048 , truncation=True)
  model_inputs['labels'] = labels['input_ids']
  return model_inputs

# Tokenize the dataset in batches, then print one processed example as a sanity check.
tokenized_dataset = dataset.map(function=tokenize, batched=True)
first_example = tokenized_dataset[0]
print(first_example)

In [None]:
# Illustrative fine-tuning setup only (two optimisation steps on CPU) — not
# meant to be run as a real training job.

# Collator that batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    max_steps=2,
    per_device_train_batch_size=8,
    learning_rate=1e-5,
    use_cpu=True,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)
trainer.train()

In [None]:
# Replace `tokenizer` and `model` with the published LeanDojo tactic-generator checkpoint.
TACGEN_CHECKPOINT = 'kaiyuy/leandojo-lean4-tacgen-byt5-small'
tokenizer = AutoTokenizer.from_pretrained(TACGEN_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(TACGEN_CHECKPOINT)

In [None]:
def generate_tactic(state: str) -> str:
  """Generate one tactic suggestion for a proof state.

  Uses the module-level `tokenizer` and `model`.

  Args:
    state: Text of the proof state to condition the model on.

  Returns:
    The decoded tactic. It is also printed, followed by a blank line.
  """
  # `return_tensors='pt'` makes the tokenizer return PyTorch tensors, which is
  # the form `model.generate` receives as `input_ids`.
  tokenized_state = tokenizer(state , return_tensors='pt')
  tactic_ids = model.generate(tokenized_state.input_ids , max_length=1024)
  tactic = tokenizer.decode(tactic_ids[0] , skip_special_tokens=True)
  print(tactic , end='\n\n')
  return tactic

In [None]:
def generate_tacticS(state: str , k: int = 16) -> List[str]:
  """Generate `k` tactic candidates for a proof state using beam search.

  Uses the module-level `tokenizer` and `model`.

  Args:
    state: Text of the proof state to condition the model on.
    k: Number of beams, and also the number of candidates returned.

  Returns:
    The decoded candidate tactics, in the order `model.generate` returns them.
  """
  # `return_tensors='pt'` makes the tokenizer return PyTorch tensors, which is
  # the form `model.generate` receives as `input_ids`.
  tokenized_state = tokenizer(state , return_tensors='pt')
  tactic_candidates_ids = model.generate(
      tokenized_state.input_ids ,
      max_length=256,
      num_beams=k,
      length_penalty=0.0,
      do_sample=False,  # deterministic beam search, no sampling
      num_return_sequences=k,
      output_scores=True,
      early_stopping=True
  )
  tactic_candidates = tokenizer.batch_decode(
      tactic_candidates_ids,
      skip_special_tokens=True
  )
  return tactic_candidates

In [None]:
# Switch to the small example repository and select the theorem to prove.
# NOTE(review): the commit is an empty string here — confirm which revision should be used.
repo = LeanGitRepo(url='https://github.com/yangky11/lean4-example', commit='')
theorem = Theorem(repo=repo, file_path='Lean4Example.lean', full_name='add_abc')

In [None]:
# Type aliases: a tactic is represented as a plain string, a proof as a list of tactics.
Tactic = str
Proof = List[Tactic]

# Number of tactic candidates requested from generate_tacticS for each state.
num_candidates: int = 16
# NOTE(review): presumably the depth bound for the proof search — confirm against `search`.
depth_limits: int = 3

def search(state: TacticState , depth: int) -> Optional[Proof]:
  """Depth-first search for a sequence of tactics that finishes the proof.

  At each state, `num_candidates` tactics are generated with
  `generate_tacticS` and tried in order through `dojo.run_tac`. `dojo` is read
  from the enclosing scope and is not defined in this cell — it must exist
  before `search` is called.

  Args:
    state: Tactic state to continue the proof from.
    depth: Number of tactics already applied on the current path.

  Returns:
    The list of tactics that completes the proof from `state`, or None if no
    proof is found within `depth_limits` steps.
  """
  if depth >= depth_limits:
    return None

  tactics = generate_tacticS(state.pp , num_candidates)

  for tac in tactics:
    next_state = dojo.run_tac(state , tac)
    if isinstance(next_state , ProofFinished):
      return [tac] # found proof
    # Recurse only into real tactic states: errors and any other non-state
    # result have no `.pp` to generate tactics from.
    if isinstance(next_state , TacticState):
      # recursive dfs
      subproof = search(next_state , depth + 1)
      if subproof is not None:
        return [tac] + subproof

  return None

In short, the pipeline works as follows:
- LeanDojo extracts state-tactic pairs from mathlib