NeurIPS 2023 Tutorial on Machine Learning for Theorem Proving
=============================================================

In [21]:
import torch
import random
import numpy as np
from tqdm import tqdm
from pprint import pprint
from datasets import Dataset
from transformers import (
  AutoModelForSeq2SeqLM,
  AutoTokenizer,
  Seq2SeqTrainer,
  Seq2SeqTrainingArguments,
  DataCollatorForSeq2Seq,
)
from typing import List, Dict, Optional

# https://arxiv.org/abs/2109.08203
random.seed(3407)
np.random.seed(3407)
torch.manual_seed(3407)

<torch._C.Generator at 0x7fe735ec1010>

## Roadmap

* Goal: Use language models to build a theorem prover that runs directly in Lean.
* Training the tactic generator
  * Use [**LeanDojo**](https://github.com/lean-dojo/LeanDojo) to extract data (state-tactic pairs) from mathlib.
  * Finetuning a language model for tactic generation
* Searching for proofs
  * Interact with Lean using [**LeanDojo**](https://github.com/lean-dojo/LeanDojo)
  * Proof search with DFS
* Using the model in Lean with [**Lean Copilot**](https://github.com/lean-dojo/LeanInfer)

## Data Extraction

We use **[LeanDojo](https://leandojo.org/)** to extract state-tactic pairs from mathlib.

### Trace the Repo

In [4]:
from lean_dojo import *

repo = LeanGitRepo(
    "https://github.com/leanprover-community/mathlib4",
    "3ce43c18f614b76e161f911b75a3e1ef641620ff",
)

# repo.show()

In [None]:
traced_repo = trace(repo)  # A few minutes, depending on #CPUs.

2023-12-04 09:16:53.097 | INFO     | lean_dojo.data_extraction.trace:trace:182 - Loading the traced repo from /home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib4-3ce43c18f614b76e161f911b75a3e1ef641620ff/mathlib4
2023-12-04 09:16:55,271	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
100%|█████████████████████████████████████████████████████████████████| 4462/4462 [07:15<00:00, 10.24it/s]
Following Github server redirection from /repos/mhuisi/lean4-cli to /repositories/341363356


### Extract State-Tactic Pairs

`traced_repo` is a data structure containing all data extracted from `repo`. We can post-process it to extract state-tactic pairs.

In [5]:
theorems = traced_repo.get_traced_theorems()
print(f"{len(theorems)} theorems/proofs extracted")

103234 theorems/proofs extracted


In [6]:
state_tactic_pairs = []

for thm in tqdm(theorems):
  for t in thm.get_traced_tactics():
    state_tactic_pairs.append({
        "state": t.state_before, 
        "tactic": t.tactic
    })

print(f"{len(state_tactic_pairs)} state-tactic pairs")

100%|███████████████████████████████████████████████████████████| 103234/103234 [00:10<00:00, 9582.12it/s]

245127 state-tactic pairs
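Tracing takes minutes, so it can be convenient to cache the extracted pairs on disk. Below is a minimal sketch using JSON Lines; `save_pairs`, `load_pairs`, and the file name are hypothetical helpers for illustration, not part of LeanDojo.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_pairs(pairs: List[Dict[str, str]], path: str) -> None:
    """Write one JSON object per line; ensure_ascii=False keeps symbols like ⊢ readable."""
    with open(path, "w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")


def load_pairs(path: str) -> List[Dict[str, str]]:
    """Read the pairs back, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Round-trip a small stand-in for `state_tactic_pairs`.
demo_pairs = [{"state": "n : ℕ\n⊢ gcd n n = n", "tactic": "rw [gcd_comm]"}]
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "state_tactic_pairs.jsonl")
    save_pairs(demo_pairs, demo_path)
    restored = load_pairs(demo_path)
```

Newlines inside a proof state are escaped by `json.dumps`, so each pair still occupies exactly one line of the file.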





In [7]:
st = state_tactic_pairs[0]

In [8]:
print(st["state"])

α : Type u_1
β : Type u_2
ks : Array α
vs : Array β
h : Array.size ks = Array.size vs
i : Fin (Array.size ks)
j : Fin (Array.size vs)
k : α
v : β
⊢ Array.size (Array.set ks i k) = Array.size (Array.set vs j v)


In [9]:
print(st["tactic"])

simp [h]


## Finetuning Language Models for Tactic Generation

There are many excellent libraries for finetuning tactic generators (e.g., [PyTorch Lightning](https://lightning.ai/), [ReProver](https://github.com/lean-dojo/ReProver)). The code below only illustrates the process; do not use it in production.

We finetune a [ByT5](https://arxiv.org/abs/2105.13626) model. ByT5 is a token-free variant of T5: it operates directly on UTF-8 bytes instead of a learned subword vocabulary, and keeps the same encoder-decoder Transformer architecture.

In [11]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

Let's pick a random subset of the data and look at one example.

In [22]:
dataset = Dataset.from_list(state_tactic_pairs).shuffle().select(range(10000))
pprint(dataset[0])

{'state': 'α : Type u_1\n'
          'β : Type u_2\n'
          'ι : Type u_3\n'
          'mα : MeasurableSpace α\n'
          'mβ : MeasurableSpace β\n'
          'γ : Type u_4\n'
          'mγ : MeasurableSpace γ\n'
          's✝ : Set (β × γ)\n'
          'κ : { x // x ∈ kernel α β }\n'
          'inst✝¹ : IsSFiniteKernel κ\n'
          'η : { x // x ∈ kernel (α × β) γ }\n'
          'inst✝ : IsSFiniteKernel η\n'
          'a✝ : α\n'
          's : Set β\n'
          't : Set γ\n'
          'hs : MeasurableSet s\n'
          'ht : MeasurableSet t\n'
          'a : α\n'
          'u : Set (β × γ)\n'
          'hu : MeasurableSet u\n'
          'b : β\n'
          '⊢ ↑↑(↑η (a, b)) {c | (b, c) ∈ u ∧ b ∈ s ∧ c ∈ t} = Set.indicator s '
          '(fun b => ↑↑(↑η (a, b)) ({c | (b, c) ∈ u} ∩ t)) b',
 'tactic': 'rw [Set.indicator_apply]'}


In [25]:
def tokenize(examples):
  model_inputs = tokenizer(examples["state"], max_length=2048, truncation=True)
  labels = tokenizer(text_target=examples["tactic"], max_length=2048, truncation=True)
  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

tokenized_dataset = dataset.map(tokenize, batched=True)

print(tokenized_dataset[0])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

{'state': 'α : Type u_1\nβ : Type u_2\nι : Type u_3\nmα : MeasurableSpace α\nmβ : MeasurableSpace β\nγ : Type u_4\nmγ : MeasurableSpace γ\ns✝ : Set (β × γ)\nκ : { x // x ∈ kernel α β }\ninst✝¹ : IsSFiniteKernel κ\nη : { x // x ∈ kernel (α × β) γ }\ninst✝ : IsSFiniteKernel η\na✝ : α\ns : Set β\nt : Set γ\nhs : MeasurableSet s\nht : MeasurableSet t\na : α\nu : Set (β × γ)\nhu : MeasurableSet u\nb : β\n⊢ ↑↑(↑η (a, b)) {c | (b, c) ∈ u ∧ b ∈ s ∧ c ∈ t} = Set.indicator s (fun b => ↑↑(↑η (a, b)) ({c | (b, c) ∈ u} ∩ t)) b', 'tactic': 'rw [Set.indicator_apply]', 'input_ids': [209, 180, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 52, 13, 209, 181, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 53, 13, 209, 188, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 54, 13, 112, 209, 180, 35, 61, 35, 80, 104, 100, 118, 120, 117, 100, 101, 111, 104, 86, 115, 100, 102, 104, 35, 209, 180, 13, 112, 209, 181, 35, 61, 35, 80, 104, 100, 118, 120, 117, 100, 101, 111, 104, 86, 115, 100, 102, 104, 35, 209, 181, 13, 209,
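The `input_ids` above can be reproduced by hand, which shows what byte-level tokenization means in practice. The sketch below assumes the ByT5 convention that IDs 0, 1, and 2 are reserved for padding, end-of-sequence, and unknown tokens, so each UTF-8 byte `b` maps to ID `b + 3`. The helper names are hypothetical and this is not the Hugging Face implementation.

```python
from typing import List

# Assumed ByT5 convention: IDs 0..2 are <pad>, </s>, <unk>; byte IDs start at 3.
BYTE_OFFSET = 3
EOS_ID = 1


def byte_level_ids(text: str, add_eos: bool = True) -> List[int]:
    """Map text to byte-level IDs: one ID per UTF-8 byte, shifted by BYTE_OFFSET."""
    ids = [b + BYTE_OFFSET for b in text.encode("utf-8")]
    if add_eos:
        ids.append(EOS_ID)
    return ids


def byte_level_decode(ids: List[int]) -> str:
    """Invert byte_level_ids, skipping special tokens and IDs outside the byte range."""
    raw = bytes(i - BYTE_OFFSET for i in ids if BYTE_OFFSET <= i < BYTE_OFFSET + 256)
    return raw.decode("utf-8", errors="ignore")


first_line = "α : Type u_1\n"
print(byte_level_ids(first_line, add_eos=False))
# [209, 180, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 52, 13]
print(len(first_line), len(byte_level_ids(first_line, add_eos=False)))
# 13 14
```

Non-ASCII symbols such as `α` or `⊢` take two or three IDs each, so the `max_length=2048` limit in `tokenize` counts bytes rather than characters.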

In [13]:
# Illustration only: two optimizer steps on CPU, not a real training run.
training_args = Seq2SeqTrainingArguments(
  output_dir="./results",
  learning_rate=1e-5,
  per_device_train_batch_size=8,
  max_steps=2,
  use_cpu=True,
)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

trainer = Seq2SeqTrainer(
  model=model,
  tokenizer=tokenizer,
  args=training_args,
  train_dataset=tokenized_dataset,
  data_collator=data_collator,
)

trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.




TrainOutput(global_step=2, training_loss=4.278001308441162, metrics={'train_runtime': 214.2875, 'train_samples_per_second': 0.075, 'train_steps_per_second': 0.009, 'total_flos': 31754282262528.0, 'train_loss': 4.278001308441162, 'epoch': 0.0})

## Inspecting the Trained Tactic Generator

In [26]:
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")

In [28]:
def generate_one_tactic(state: str) -> None:
  """Generate a single tactic and print it."""
  tokenized_state = tokenizer(state, return_tensors="pt")
  tactic_ids = model.generate(tokenized_state.input_ids, max_length=1024)
  tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
  print(tactic, end="\n\n")

In [29]:
generate_one_tactic("n : ℕ\n⊢ gcd n n = n")

rw [gcd_comm]



In [52]:
def generate_tactics(state: str, k: int = 8) -> List[str]:
  """Generate multiple tactics via beam search."""
  tokenized_state = tokenizer(state, return_tensors="pt")
  tactic_candidates_ids = model.generate(
    tokenized_state.input_ids,
    max_length=256,
    num_beams=k,
    length_penalty=0.0,
    do_sample=False,
    num_return_sequences=k,
    early_stopping=False,
  )
  tactic_candidates = tokenizer.batch_decode(
    tactic_candidates_ids, skip_special_tokens=True
  )
  return tactic_candidates

In [87]:
for tac in generate_tactics("n : ℕ\n⊢ gcd n n = n"):
  print(tac)

rw [gcd_comm]
induction' n with n IH
induction' n with n hn
cases n
rw [gcd]
induction' n with n ih
unfold gcd
rw [gcd_comm, gcd_gcd_self_right]
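To make `num_beams` and `num_return_sequences` concrete, here is a toy beam search over a hand-written next-token table. The table, token names, and function are hypothetical, and the procedure is simplified relative to the Hugging Face implementation. It ranks hypotheses by their summed log-probability with no length normalization, which corresponds to `length_penalty=0.0`.

```python
import math
from typing import Dict, List, Tuple

Prefix = Tuple[str, ...]

# Hypothetical next-token distributions, keyed by the tokens generated so far.
TOY_MODEL: Dict[Prefix, Dict[str, float]] = {
    (): {"rw": 0.5, "simp": 0.3, "cases": 0.2},
    ("rw",): {"[gcd]": 0.6, "[gcd_comm]": 0.4},
    ("simp",): {"<eos>": 0.7, "[gcd]": 0.3},
    ("cases",): {"n": 1.0},
}


def next_token_probs(prefix: Prefix) -> Dict[str, float]:
    """Unknown prefixes end the sequence with probability 1."""
    return TOY_MODEL.get(prefix, {"<eos>": 1.0})


def toy_beam_search(num_beams: int, max_length: int = 4) -> List[Tuple[str, float]]:
    """Return up to `num_beams` finished sequences with their log-probabilities."""
    beams: List[Tuple[Prefix, float]] = [((), 0.0)]
    finished: List[Tuple[Prefix, float]] = []
    for _ in range(max_length):
        candidates = []
        for prefix, score in beams:
            for token, prob in next_token_probs(prefix).items():
                hypothesis = (prefix + (token,), score + math.log(prob))
                if token == "<eos>":
                    finished.append(hypothesis)
                else:
                    candidates.append(hypothesis)
        # Keep only the `num_beams` best unfinished prefixes.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
        if not beams:
            break
    finished.sort(key=lambda c: c[1], reverse=True)
    return [(" ".join(prefix[:-1]), score) for prefix, score in finished[:num_beams]]


for tactic, log_prob in toy_beam_search(num_beams=2):
    print(f"{tactic}  (p = {math.exp(log_prob):.2f})")
# rw [gcd]  (p = 0.30)
# simp  (p = 0.21)
```

With `num_beams=1` the search degenerates to greedy decoding and returns only `rw [gcd]`; a wider beam trades compute for more diverse candidates, which is what the proof search relies on.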


## Interacting with Lean

[**LeanDojo**](https://github.com/lean-dojo/LeanDojo) supports interacting with Lean in Python. We'll use the `gcd_self` theorem as an example.

![gcd_self.jpg](./gcd_self.jpg)

In [33]:
repo = LeanGitRepo(
    "https://github.com/yangky11/lean4-example",
    "f63c82aebf28d6b5d37d0eb043bca469fa156d57",
)
theorem = Theorem(repo, "Gcd.lean", "Hidden.gcd_self")

# Enter the context manager by hand so that `dojo` stays alive across notebook cells.
dojo, s0 = Dojo(theorem).__enter__()



In [34]:
print(s0.pp)

n : ℕ
⊢ gcd n n = n


In [37]:
s1 = dojo.run_tac(s0, "revert n")

print(s1.pp)

⊢ ∀ (n : ℕ), gcd n n = n


In [38]:
s2 = dojo.run_tac(s1, "hello world!")

s2

LeanError(error='<stdin>:1:1: unknown tactic')

In [39]:
dojo.run_tac(s2, "skip")

RuntimeError: Attempting to run a tactic on an invalid state LeanError(error='<stdin>:1:1: unknown tactic').

In [42]:
s3 = dojo.run_tac(s0, "cases n <;> unfold gcd")

print(s3.pp)

case zero
⊢ zero = zero

case succ
n✝ : Nat
⊢ gcd (succ n✝ % (n✝ + 1)) (n✝ + 1) = succ n✝


In [43]:
s4 = dojo.run_tac(s3, "rfl")

print(s4.pp)

case succ
n✝ : Nat
⊢ gcd (succ n✝ % (n✝ + 1)) (n✝ + 1) = succ n✝


In [45]:
s5 = dojo.run_tac(s4, "rw [mod_self]")

print(s5.pp)

case succ
n✝ : Nat
⊢ gcd 0 (n✝ + 1) = succ n✝


In [46]:
s6 = dojo.run_tac(s5, "simp [gcd]")

s6

ProofFinished(tactic_state_id=8, message='')

## Proof Search

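We combine the tactic generator with depth-first search (DFS) to search for proofs. At each proof state the model proposes candidate tactics, and we try them in order: a tactic that finishes the proof ends the search, a tactic that yields a new valid state is explored recursively, and errors or hitting the depth limit trigger backtracking.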

In [57]:
Tactic = str
Proof = List[Tactic]
num_candidates = 32
depth_limit = 5

def dfs(dojo: Dojo, state: TacticState, depth: int) -> Optional[Proof]:
    # Check the depth limit.
    if depth > depth_limit:
        print("Hit the depth limit! Backtracking...")
        return None

    # Generate tactic candidates.
    print(f"Current goal:\n{state.pp}\n")
    print("Generating tactics...")
    tactics = generate_tactics(state.pp, num_candidates)
    print(f"{num_candidates} tactic candidates:")
    print("\n".join(tactics) + "\n")

    # Running the generated tactics.
    for tac in tactics:
        next_state = dojo.run_tac(state, tac)
        if isinstance(next_state, ProofFinished):
            # Success!
            print("Found a proof!")
            return [tac]
        elif isinstance(next_state, LeanError):
            pass
        else:
            # Call `dfs` recursively.
            assert isinstance(next_state, TacticState)
            print(f"Applied tactic: {tac}\n")
            subproof = dfs(dojo, next_state, depth + 1)
            if subproof is not None:
                return [tac] + subproof
    
    print("Unable to prove the current goal. Backtracking...")
    return None


def search(thm: Theorem) -> Optional[Proof]:
    # Use the argument `thm`, not the global `theorem`.
    with Dojo(thm) as (dojo, s0):
        return dfs(dojo, s0, 0)
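The control flow of `dfs` can be exercised without Lean or a model by swapping in a toy environment. In the sketch below, the state names, transition table, and candidate lists are all made up for illustration: a missing transition plays the role of `LeanError`, and the special state `"DONE"` plays the role of `ProofFinished`.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical toy environment: (state, tactic) -> next state.
TRANSITIONS: Dict[Tuple[str, str], str] = {
    ("goal", "bad_path"): "stuck",
    ("goal", "cases n"): "two goals",
    ("two goals", "rfl"): "one goal",
    ("one goal", "simp [gcd]"): "DONE",
}

# Hypothetical stand-in for the tactic generator: fixed candidates per state.
CANDIDATES: Dict[str, List[str]] = {
    "goal": ["hello world!", "bad_path", "cases n"],
    "stuck": ["simp"],
    "two goals": ["rfl"],
    "one goal": ["rfl", "simp [gcd]"],
}


def toy_dfs(state: str, depth: int, depth_limit: int = 5) -> Optional[List[str]]:
    """Depth-first search with the same structure as `dfs`, minus Lean."""
    if depth > depth_limit:
        return None  # Depth limit hit: backtrack.
    for tac in CANDIDATES.get(state, []):
        next_state = TRANSITIONS.get((state, tac))
        if next_state is None:
            continue  # Invalid tactic (the LeanError case): try the next one.
        if next_state == "DONE":
            return [tac]  # Proof finished.
        subproof = toy_dfs(next_state, depth + 1, depth_limit)
        if subproof is not None:
            return [tac] + subproof
    return None  # Every candidate failed: backtrack.


print(toy_dfs("goal", 0))
# ['cases n', 'rfl', 'simp [gcd]']
print(toy_dfs("goal", 0, depth_limit=1))
# None
```

The first call skips the invalid tactic, backtracks out of the dead end reached via `bad_path`, and then finds the three-step proof. The second call shows that a depth limit that is too small makes the same search fail.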

In [54]:
print(theorem.full_name)

Hidden.gcd_self


In [58]:
proof = search(theorem)



Current goal:
n : ℕ
⊢ gcd n n = n

Generating tactics...
32 tactic candidates:
rw [gcd_comm]
induction' n with n IH
induction' n with n hn
cases n
rw [gcd]
induction' n with n ih
unfold gcd
rw [gcd_comm, gcd_gcd_self_right]
cases' n with n
rw [gcd, gcd_comm]
rcases n.eq_zero_or_pos with (rfl | hn)
simp [gcd]
dsimp [gcd]
simp [gcd_comm]
rw [gcd_comm, gcd_succ_self]
rfl
induction' n with n IH generalizing n
induction' n with n n_ih
rw [gcd_comm, gcd_eq_left_iff_dvd]
induction' n with n ihn
rw [gcd_comm, gcd_gcd_self_left]
rw [gcd_succ]
rw [gcd_comm, gcd_self_right]
by_cases hn : n = 0
rcases eq_or_ne n 0 with (rfl | hn)
rw [gcd, Nat.gcd_comm]
by_cases n0 : n = 0
simp
simp only [gcd_comm]
exact gcd_succ_self n
rw [gcd, gcd_succ_self]
apply gcd_succ_succ

Applied tactic: cases n

Current goal:
case zero
⊢ gcd zero zero = zero

case succ
n✝ : Nat
⊢ gcd (succ n✝) (succ n✝) = succ n✝

Generating tactics...
32 tactic candidates:
case zero => rfl
. rfl
case zero => simp
rfl
simp
. simp
case zer

In [59]:
print("\n".join(proof))

cases n
case zero => rfl
simp [gcd]
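The proof returned by `search` is a flat list of tactic strings. Below is a minimal sketch of turning it into an indented `by` block that can be pasted after a theorem statement. `format_tactic_block` is a hypothetical helper, and whether the resulting block compiles still has to be checked in Lean.

```python
from typing import List


def format_tactic_block(tactics: List[str], indent: str = "  ") -> str:
    """Render a list of tactics as a Lean 4 style `by` block, one tactic per line."""
    lines = ["by"]
    for tac in tactics:
        # A single tactic may itself span several lines; indent each of them.
        lines.extend(indent + line for line in tac.splitlines())
    return "\n".join(lines)


print(format_tactic_block(["cases n", "case zero => rfl", "simp [gcd]"]))
# by
#   cases n
#   case zero => rfl
#   simp [gcd]
```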


## Using the Model in Lean