NeurIPS 2023 Tutorial on Machine Learning for Theorem Proving
=============================================================

In [53]:
import torch
import random
import numpy as np
from tqdm import tqdm
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from typing import List, Dict, Optional

# Seed every RNG this notebook touches (Python, NumPy, PyTorch) so reruns match.
# Seed value from https://arxiv.org/abs/2109.08203
SEED = 3407
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fc590d5cff0>

## Roadmap

## Data Extraction

We use [LeanDojo](https://leandojo.org/) to trace mathlib4 at a fixed commit and extract (proof state, tactic) pairs from its proofs.

### Trace the Repo

In [2]:
from lean_dojo import *

# Pin mathlib4 to a fixed commit so the extracted data is reproducible.
MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4"
MATHLIB4_COMMIT = "3ce43c18f614b76e161f911b75a3e1ef641620ff"
repo = LeanGitRepo(MATHLIB4_URL, MATHLIB4_COMMIT)

repo.show()

In [3]:
# Trace the pinned repo. A previously traced commit is loaded from LeanDojo's
# cache directory (e.g. under ~/.cache/lean_dojo) instead of being re-traced.
traced_repo = trace(repo)  # A few minutes, depending on #CPUs.

[32m2023-12-03 08:49:04.665[0m | [1mINFO    [0m | [36mlean_dojo.data_extraction.trace[0m:[36mtrace[0m:[36m182[0m - [1mLoading the traced repo from /home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib4-3ce43c18f614b76e161f911b75a3e1ef641620ff/mathlib4[0m
2023-12-03 08:49:06,885	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8266 [39m[22m
100%|████████████████████████████████████████████████████████████████████████████████| 4462/4462 [08:36<00:00,  8.65it/s]
Following Github server redirection from /repos/mhuisi/lean4-cli to /repositories/341363356


## Extract State-Tactic Pairs

In [4]:
# All theorems/proofs LeanDojo recorded while tracing the repo.
theorems = traced_repo.get_traced_theorems()
print(f"{len(theorems)} theorems/proofs extracted")

103234 theorems/proofs extracted


In [5]:
# Flatten every traced tactic of every theorem into a {"state", "tactic"} record,
# where "state" is the proof state right before the tactic was applied.
state_tactic_pairs = [
    {"state": traced_tactic.state_before, "tactic": traced_tactic.tactic}
    for traced_thm in tqdm(theorems)
    for traced_tactic in traced_thm.get_traced_tactics()
]

print(f"{len(state_tactic_pairs)} state-tactic pairs")

100%|██████████████████████████████████████████████████████████████████████████| 103234/103234 [00:13<00:00, 7450.60it/s]

245127 state-tactic pairs





In [8]:
# Look at one extracted example.
st = state_tactic_pairs[0]

In [9]:
# The proof state before the tactic: hypotheses above `⊢`, the goal below it.
print(st["state"])

α : Type u_1
β : Type u_2
ks : Array α
vs : Array β
h : Array.size ks = Array.size vs
i : Fin (Array.size ks)
j : Fin (Array.size vs)
k : α
v : β
⊢ Array.size (Array.set ks i k) = Array.size (Array.set vs j v)


In [10]:
# The tactic that was applied to that state.
print(st["tactic"])

simp [h]


## Finetuning Language Models for Tactic Generation

In [11]:
# Pretrained (not yet finetuned) ByT5 checkpoint used as the starting point.
BASE_CHECKPOINT = "google/byt5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)

In [12]:
MAX_SEQ_LEN = 2048  # truncation length (in tokens) for both states and tactics
NUM_TRAIN_EXAMPLES = 10000  # small subset to keep the demo cheap
DATASET_SEED = 3407  # explicit shuffle seed, independent of global RNG state


def tokenize(examples):
    """Tokenize a batch of state-tactic pairs for seq2seq training.

    Args:
        examples: a batch from `Dataset.map(..., batched=True)` with
            "state" and "tactic" columns (lists of strings).

    Returns:
        The tokenizer output for the states (model inputs) with an extra
        "labels" entry holding the token ids of the corresponding tactics.
    """
    model_inputs = tokenizer(examples["state"], max_length=MAX_SEQ_LEN, truncation=True)
    labels = tokenizer(text_target=examples["tactic"], max_length=MAX_SEQ_LEN, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Pass an explicit seed so the selected subset doesn't depend on global RNG
# state or on which cells ran before this one.
full_dataset = Dataset.from_list(state_tactic_pairs).shuffle(seed=DATASET_SEED)
# Guard against having fewer pairs than requested (select would raise).
dataset = full_dataset.select(range(min(NUM_TRAIN_EXAMPLES, len(full_dataset))))
tokenized_dataset = dataset.map(tokenize, batched=True)

tokenized_dataset

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['state', 'tactic', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 10000
})

In [13]:
# This is just an example: with max_steps=2 on CPU this only smoke-tests the
# training loop; it does not produce a useful tactic generator. The finetuned
# checkpoint is loaded from the Hub in the next section instead.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    max_steps=2,  # stop after two optimizer steps
    use_cpu=True,
)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss


TrainOutput(global_step=2, training_loss=4.278001308441162, metrics={'train_runtime': 214.2875, 'train_samples_per_second': 0.075, 'train_steps_per_second': 0.009, 'total_flos': 31754282262528.0, 'train_loss': 4.278001308441162, 'epoch': 0.0})

## Inspecting the Trained Tactic Generator

In [14]:
# LeanDojo's ByT5 tactic generator, already finetuned on Lean 4 state-tactic pairs.
TACGEN_CHECKPOINT = "kaiyuy/leandojo-lean4-tacgen-byt5-small"
tokenizer = AutoTokenizer.from_pretrained(TACGEN_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(TACGEN_CHECKPOINT)

In [15]:
# Which model class the checkpoint loaded as.
type(model)

transformers.models.t5.modeling_t5.T5ForConditionalGeneration

In [16]:
# Which tokenizer class the checkpoint loaded as.
type(tokenizer)

transformers.models.byt5.tokenization_byt5.ByT5Tokenizer

In [17]:
def generate_one_tactic(state: str, max_length: int = 1024) -> str:
    """Generate a single tactic for a pretty-printed Lean proof state.

    Uses the model's default generation settings. The tactic is printed
    (followed by a blank line) and also returned, matching the `-> str`
    annotation; previously the function returned None.

    Args:
        state: pretty-printed tactic state, e.g. "n : ℕ\\n⊢ gcd n n = n".
        max_length: maximum length of the generated sequence, in tokens.

    Returns:
        The decoded tactic string, with special tokens removed.
    """
    tokenized_state = tokenizer(state, return_tensors="pt")
    tactic_ids = model.generate(
        tokenized_state.input_ids,
        attention_mask=tokenized_state.attention_mask,
        max_length=max_length,
    )
    tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
    print(tactic, end="\n\n")
    return tactic

In [18]:
# Try the model on a simple goal: gcd n n = n for a natural number n.
generate_one_tactic("n : ℕ\n⊢ gcd n n = n")

rw [gcd_comm]



In [86]:
def generate_tactics(state: str, k: int = 8, max_length: int = 256) -> List[str]:
    """Generate `k` candidate tactics for a proof state via beam search.

    Args:
        state: pretty-printed Lean tactic state (e.g. `TacticState.pp`).
        k: beam width; also the number of candidates returned.
        max_length: maximum length of each generated tactic, in tokens.

    Returns:
        The `k` decoded tactic strings, in the order `generate` returns the beams.
    """
    tokenized_state = tokenizer(state, return_tensors="pt")
    tactic_candidates_ids = model.generate(
        tokenized_state.input_ids,
        attention_mask=tokenized_state.attention_mask,
        max_length=max_length,
        num_beams=k,
        # length ** 0.0 == 1, so beam scores are not normalized by length.
        length_penalty=0.0,
        do_sample=False,
        num_return_sequences=k,
        early_stopping=False,
    )
    tactic_candidates = tokenizer.batch_decode(
        tactic_candidates_ids, skip_special_tokens=True
    )
    return tactic_candidates

In [87]:
# Beam-search candidates for the same gcd goal, one per line.
gcd_candidates = generate_tactics("n : ℕ\n⊢ gcd n n = n")
for candidate in gcd_candidates:
    print(candidate)

rw [gcd_comm]
induction' n with n IH
induction' n with n hn
cases n
rw [gcd]
induction' n with n ih
unfold gcd
rw [gcd_comm, gcd_gcd_self_right]


## Interacting with Lean

In [39]:
# A small example repo (pinned commit) containing the theorem to prove.
# Note: this rebinds `repo`, which previously pointed at mathlib4.
repo = LeanGitRepo(
    "https://github.com/yangky11/lean4-example",
    "f63c82aebf28d6b5d37d0eb043bca469fa156d57",
)
# Theorem `Hidden.gcd_self` in the file Gcd.lean of that repo.
theorem = Theorem(repo, "Gcd.lean", "Hidden.gcd_self")

# For some theorems, it might take a few minutes.
# `Dojo` is a context manager (it is used with `with` in the proof-search cell).
# Calling `__enter__()` by hand keeps it open across the following cells for
# interactive use, but nothing here exits it.
# TODO(review): call `dojo.__exit__(None, None, None)` when done to clean up.
dojo, s0 = Dojo(theorem).__enter__()



In [40]:
# Pretty-printed initial proof state of the theorem.
print(s0.pp)

n : ℕ
⊢ gcd n n = n


In [41]:
# Each run_tac call yields a new state; older states (s0, s1) are reused below.
s1 = dojo.run_tac(s0, "revert n")

print(s1.pp)

⊢ ∀ (n : ℕ), gcd n n = n


In [42]:
# `intro n` undoes the `revert n`, giving back the original goal.
s2 = dojo.run_tac(s1, "intro n")

print(s2.pp)

n : ℕ
⊢ gcd n n = n


In [43]:
# An invalid tactic does not raise here: run_tac returns a LeanError value.
s3 = dojo.run_tac(s1, "hello world!")

s3

LeanError(error='<stdin>:1:1: unknown tactic')

In [44]:
# Running a tactic on an error state (a LeanError) is rejected with a
# RuntimeError. Catch and show it so "Restart & Run All" doesn't stop here.
try:
    dojo.run_tac(s3, "skip")
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: Attempting to run a tactic on an invalid state LeanError(error='<stdin>:1:1: unknown tactic').

In [45]:
# `<;>` runs `unfold gcd` on each goal produced by `cases n` (zero and succ).
s4 = dojo.run_tac(s0, "cases n <;> unfold gcd")

print(s4.pp)

case zero
⊢ zero = zero

case succ
n✝ : Nat
⊢ gcd (succ n✝ % (n✝ + 1)) (n✝ + 1) = succ n✝


In [46]:
# `rfl` closes the first goal (case zero); only case succ remains.
s5 = dojo.run_tac(s4, "rfl")

print(s5.pp)

case succ
n✝ : Nat
⊢ gcd (succ n✝ % (n✝ + 1)) (n✝ + 1) = succ n✝


In [48]:
# Rewrite `succ n✝ % (n✝ + 1)` to `0` with `mod_self`.
s6 = dojo.run_tac(s5, "rw [mod_self]")

print(s6.pp)

case succ
n✝ : Nat
⊢ gcd 0 (n✝ + 1) = succ n✝


In [50]:
# `simp [gcd]` closes the last goal, so run_tac returns ProofFinished.
s7 = dojo.run_tac(s6, "simp [gcd]")

s7

ProofFinished(tactic_state_id=8, message='')

## Proof Search

In [95]:
Tactic = str
Proof = List[Tactic]
num_candidates = 32  # beam width / number of tactics tried per proof state
depth_limit = 5  # dfs backtracks once `depth` exceeds this value


def dfs(dojo: Dojo, state: TacticState, depth: int) -> Optional[Proof]:
    """Depth-first, model-guided proof search from `state`.

    At each state, generate `num_candidates` tactics, try them in order, and
    recurse into every resulting proof state until one leads to a proof.

    Args:
        dojo: the open Dojo used to run tactics.
        state: the current (unfinished) tactic state.
        depth: number of tactics applied on the path from the root to `state`.

    Returns:
        The tactics proving `state`, in order, or None if no proof was found
        before exceeding `depth_limit`.
    """
    if depth > depth_limit:
        print("Hit the depth limit! Backtracking...")
        return None

    print(f"Current goal:\n{state.pp}\n")

    print("Generating tactics...")
    tactics = generate_tactics(state.pp, num_candidates)
    print(f"{num_candidates} tactic candidates:")
    print("\n".join(tactics) + "\n")

    for tac in tactics:
        next_state = dojo.run_tac(state, tac)
        if isinstance(next_state, ProofFinished):
            print("Found a proof!")
            return [tac]
        if not isinstance(next_state, TacticState):
            # LeanError, or any other non-state outcome: treat this tactic as
            # a dead end instead of crashing the whole search.
            continue
        print(f"Applied tactic: {tac}\n")
        subproof = dfs(dojo, next_state, depth + 1)
        if subproof is not None:
            return [tac] + subproof

    print("Unable to prove the current goal. Backtracking...")
    return None


def search(thm: Theorem) -> Optional[Proof]:
    """Open a Dojo for `thm` and run `dfs` from its initial state.

    Returns:
        The list of tactics proving `thm`, or None if the search failed.
    """
    # Use the argument; previously this ignored `thm` and read the global `theorem`.
    with Dojo(thm) as (dojo, s0):
        return dfs(dojo, s0, 0)

In [96]:
# The theorem the proof search below will try to prove.
print(theorem.full_name)

Hidden.gcd_self


In [97]:
# Run the model-guided DFS; it prints each goal, the candidates, and backtracking.
proof = search(theorem)



Current goal:
n : ℕ
⊢ gcd n n = n

Generating tactics...
32 tactic candidates:
rw [gcd_comm]
induction' n with n IH
induction' n with n hn
cases n
rw [gcd]
induction' n with n ih
unfold gcd
rw [gcd_comm, gcd_gcd_self_right]
cases' n with n
rw [gcd, gcd_comm]
rcases n.eq_zero_or_pos with (rfl | hn)
simp [gcd]
dsimp [gcd]
simp [gcd_comm]
rw [gcd_comm, gcd_succ_self]
rfl
induction' n with n IH generalizing n
induction' n with n n_ih
rw [gcd_comm, gcd_eq_left_iff_dvd]
induction' n with n ihn
rw [gcd_comm, gcd_gcd_self_left]
rw [gcd_succ]
rw [gcd_comm, gcd_self_right]
by_cases hn : n = 0
rcases eq_or_ne n 0 with (rfl | hn)
rw [gcd, Nat.gcd_comm]
by_cases n0 : n = 0
simp
simp only [gcd_comm]
exact gcd_succ_self n
rw [gcd, gcd_succ_self]
apply gcd_succ_succ

Applied tactic: cases n

Current goal:
case zero
⊢ gcd zero zero = zero

case succ
n✝ : Nat
⊢ gcd (succ n✝) (succ n✝) = succ n✝

Generating tactics...
32 tactic candidates:
case zero => rfl
. rfl
case zero => simp
rfl
simp
. simp
case zer

In [99]:
# `search` returns None when no proof is found; joining None would raise TypeError.
if proof is None:
    print("No proof found.")
else:
    print("\n".join(proof))

cases n
case zero => rfl
simp [gcd]


## Using the Model in Lean