NeurIPS 2023 Tutorial on Machine Learning for Theorem Proving
=============================================================

In [1]:
import torch
import random
import numpy as np
from tqdm import tqdm
from pprint import pprint
from transformers import (
  AutoModelForSeq2SeqLM,
  AutoTokenizer,
  Seq2SeqTrainer,
  Seq2SeqTrainingArguments,
  DataCollatorForSeq2Seq,
)
from datasets import Dataset
from typing import List, Dict, Optional

# Seed Python's, NumPy's and PyTorch's RNGs with one shared, named value so the
# three calls cannot drift apart. For the choice of 3407 see:
# https://arxiv.org/abs/2109.08203
SEED = 3407
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f0db42e7110>

## Roadmap

* Goal: Using language models to build a theorem prover in Lean.
* Training the tactic generator
  * Using [**LeanDojo**](https://github.com/lean-dojo/LeanDojo) to extract data (state-tactic pairs) from mathlib.
  * Finetuning a language model for tactic generation
* Searching for proofs
  * Interacting with Lean using [**LeanDojo**](https://github.com/lean-dojo/LeanDojo)
  * Proof search with DFS
* Using the model in Lean with [**Lean Copilot**](https://github.com/lean-dojo/LeanInfer)

## Data Extraction

We use **[LeanDojo](https://leandojo.org/)** to extract state-tactic pairs from mathlib.

### Trace the Repo

In [2]:
from lean_dojo import *

In [3]:
# The repository to trace: mathlib4, given as a URL plus a second argument that
# looks like a commit hash (presumably pinning the revision — confirm against
# LeanGitRepo's signature).
repo = LeanGitRepo(
    "https://github.com/leanprover-community/mathlib4",
    "3ce43c18f614b76e161f911b75a3e1ef641620ff",
)

In [4]:
# The recorded log below says "Loading the traced repo from .../.cache/lean_dojo/...",
# which suggests the result is cached on disk and reused on later runs.
traced_repo = trace(repo)  # A few minutes, depending on #CPUs.

[32m2023-12-08 18:49:54.915[0m | [1mINFO    [0m | [36mlean_dojo.data_extraction.trace[0m:[36mtrace[0m:[36m182[0m - [1mLoading the traced repo from /home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib4-3ce43c18f614b76e161f911b75a3e1ef641620ff/mathlib4[0m
2023-12-08 18:49:57,134	INFO worker.py:1664 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
100%|██████████| 4462/4462 [07:40<00:00,  9.69it/s]  
Following Github server redirection from /repos/mhuisi/lean4-cli to /repositories/341363356


### Extract State-Tactic Pairs

`traced_repo` is a data structure containing all data extracted from `repo`. We can post-process it to extract state-tactic pairs.

In [5]:
# Pull the traced theorems out of the traced repo and report how many there are.
theorems = traced_repo.get_traced_theorems()
print(f"{len(theorems)} theorems/proofs extracted")

103234 theorems/proofs extracted


In [6]:
# One {"state", "tactic"} record per traced tactic, across all theorems.
# "state" is the proof state *before* the tactic was applied.
state_tactic_pairs = [
    {"state": traced_tactic.state_before, "tactic": traced_tactic.tactic}
    for traced_theorem in tqdm(theorems)
    for traced_tactic in traced_theorem.get_traced_tactics()
]

print(f"{len(state_tactic_pairs)} state-tactic pairs")

100%|██████████| 103234/103234 [00:11<00:00, 9271.92it/s]

245127 state-tactic pairs





In [7]:
# Inspect the first extracted pair, starting with its proof state
# (the state before the tactic was applied).
st = state_tactic_pairs[0]
print(st["state"])

α : Type u_1
β : Type u_2
ks : Array α
vs : Array β
h : Array.size ks = Array.size vs
i : Fin (Array.size ks)
j : Fin (Array.size vs)
k : α
v : β
⊢ Array.size (Array.set ks i k) = Array.size (Array.set vs j v)


In [8]:
# The tactic recorded for that same pair.
print(st["tactic"])

simp [h]


## Finetuning Language Models for Tactic Generation

There are many excellent libraries that can be used for finetuning tactic generators (e.g., [PyTorch Lightning](https://lightning.ai/), [ReProver](https://github.com/lean-dojo/ReProver)). The code below is only an illustration of the process. DO NOT USE IT FOR PRODUCTION.

We finetune a [ByT5](https://arxiv.org/abs/2105.13626) model. It is a tokenization-free version of T5, with the same encoder-decoder Transformer architecture.

In [9]:
# Load the pretrained "google/byt5-small" seq2seq model and the tokenizer of the
# same name. Both are bound to notebook-level names that later cells read.
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

Let's pick a random subset of the data and look at one example.

In [10]:
# Build a Dataset from all pairs, shuffle it, and keep the first 10,000 rows
# of the shuffled order as a small training subset.
dataset = (
    Dataset.from_list(state_tactic_pairs)
    .shuffle()
    .select(range(10000))
)
pprint(dataset[0])

{'state': 'case pos\n'
          'Ω : Type u_1\n'
          'β : Type u_2\n'
          'ι : Type u_3\n'
          'm : MeasurableSpace Ω\n'
          'inst✝ : Preorder ι\n'
          'f : Filtration ι m\n'
          'τ π : Ω → ι\n'
          'hτ : IsStoppingTime f τ\n'
          's : Set Ω\n'
          'i : ι\n'
          'this : ∀ (j : ι), {ω | τ ω = i} ∩ {ω | τ ω ≤ j} = {ω | τ ω = i} ∩ '
          '{_ω | i ≤ j}\n'
          'h : MeasurableSet (s ∩ {ω | τ ω = i})\n'
          'j : ι\n'
          'hij : i ≤ j\n'
          '⊢ MeasurableSet (s ∩ ({ω | τ ω = i} ∩ {_ω | i ≤ j}))',
 'tactic': 'simp only [hij, Set.setOf_true, Set.inter_univ]'}


In [11]:
def tokenize(examples):
  """Tokenize a batch: states become the model inputs, tactics become `labels`.

  Both sides are tokenized with max_length=2048 and truncation enabled.

  Args:
    examples: Mapping with "state" and "tactic" entries to tokenize.

  Returns:
    The tokenizer's encoding of `examples["state"]`, with an extra "labels"
    entry holding the "input_ids" of the tokenized tactics.
  """
  encoded = tokenizer(examples["state"], max_length=2048, truncation=True)
  target_ids = tokenizer(
    text_target=examples["tactic"], max_length=2048, truncation=True
  )["input_ids"]
  encoded["labels"] = target_ids
  return encoded

# Apply `tokenize` over the dataset in batches. As the printed row below shows,
# the original "state"/"tactic" text is kept alongside the new token ids.
tokenized_dataset = dataset.map(tokenize, batched=True)

print(tokenized_dataset[0])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

{'state': 'case pos\nΩ : Type u_1\nβ : Type u_2\nι : Type u_3\nm : MeasurableSpace Ω\ninst✝ : Preorder ι\nf : Filtration ι m\nτ π : Ω → ι\nhτ : IsStoppingTime f τ\ns : Set Ω\ni : ι\nthis : ∀ (j : ι), {ω | τ ω = i} ∩ {ω | τ ω ≤ j} = {ω | τ ω = i} ∩ {_ω | i ≤ j}\nh : MeasurableSet (s ∩ {ω | τ ω = i})\nj : ι\nhij : i ≤ j\n⊢ MeasurableSet (s ∩ ({ω | τ ω = i} ∩ {_ω | i ≤ j}))', 'tactic': 'simp only [hij, Set.setOf_true, Set.inter_univ]', 'input_ids': [102, 100, 118, 104, 35, 115, 114, 118, 13, 209, 172, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 52, 13, 209, 181, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 53, 13, 209, 188, 35, 61, 35, 87, 124, 115, 104, 35, 120, 98, 54, 13, 112, 35, 61, 35, 80, 104, 100, 118, 120, 117, 100, 101, 111, 104, 86, 115, 100, 102, 104, 35, 209, 172, 13, 108, 113, 118, 119, 229, 159, 160, 35, 61, 35, 83, 117, 104, 114, 117, 103, 104, 117, 35, 209, 188, 13, 105, 35, 61, 35, 73, 108, 111, 119, 117, 100, 119, 108, 114, 113, 35, 209, 188, 35, 112, 13, 210, 135, 35, 2

In [12]:
# This is just an example. Don't run it.
# Deliberately tiny configuration: max_steps=2 (the recorded TrainOutput below
# reports global_step=2) with use_cpu=True, so this only demonstrates the API.
training_args = Seq2SeqTrainingArguments(
  output_dir="./results",
  learning_rate=1e-5,
  per_device_train_batch_size=8,
  max_steps=2,
  use_cpu=True,
)

# Batch collator for seq2seq examples (presumably pads inputs and labels per
# batch — confirm in the transformers documentation).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Finetunes the notebook-level `model` on the tokenized subset built earlier.
trainer = Seq2SeqTrainer(
  model=model,
  tokenizer=tokenizer,
  args=training_args,
  train_dataset=tokenized_dataset,
  data_collator=data_collator,
)

# The return value (a TrainOutput, see below) is displayed as the cell result.
trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss


TrainOutput(global_step=2, training_loss=8.766227722167969, metrics={'train_runtime': 42.3437, 'train_samples_per_second': 0.378, 'train_steps_per_second': 0.047, 'total_flos': 25093347827712.0, 'train_loss': 8.766227722167969, 'epoch': 0.0})

## Inspecting the Trained Tactic Generator

In [12]:
# Rebind the notebook-level `tokenizer` and `model` (previously
# "google/byt5-small") to the checkpoint named below. The generation functions
# defined afterwards read these two globals.
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")

In [13]:
def generate_one_tactic(state: str) -> str:
  """Generate a single tactic for `state`, print it, and return it.

  Uses the notebook-level `tokenizer` and `model`.

  Args:
    state: Text of the proof state (goal) to feed to the model.

  Returns:
    The decoded tactic, with special tokens skipped. The same string is also
    printed, followed by a blank line.
  """
  tokenized_state = tokenizer(state, return_tensors="pt")
  # No beam-search or sampling arguments are passed; only max_length is set.
  tactic_ids = model.generate(tokenized_state.input_ids, max_length=1024)
  # generate() returns a batch; there is one input, so take the first row.
  tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
  print(tactic, end="\n\n")
  # Previously the function fell off the end and returned None despite `-> str`.
  return tactic

In [15]:
# Try the model on a goal written as a single ∀-statement with no hypotheses.
generate_one_tactic("∀ (a b c : ℕ), a + b + c = a + c + b")

intro a b c



In [17]:
def generate_tactics(state: str, k: int = 8) -> List[str]:
  """Return `k` tactic suggestions for `state` obtained with beam search.

  Uses the notebook-level `tokenizer` and `model`.

  Args:
    state: Text of the proof state (goal) to feed to the model.
    k: Used both as the number of beams and as the number of returned sequences.

  Returns:
    The decoded candidate tactics, with special tokens skipped.
  """
  encoded = tokenizer(state, return_tensors="pt")
  # Sampling is off; beam width and number of returned sequences are both `k`.
  beam_options = dict(
    max_length=256,
    num_beams=k,
    length_penalty=0.0,
    do_sample=False,
    num_return_sequences=k,
    early_stopping=False,
  )
  candidate_ids = model.generate(encoded.input_ids, **beam_options)
  return tokenizer.batch_decode(candidate_ids, skip_special_tokens=True)

In [18]:
# Print the candidates (default k=8) for a small gcd goal, one per line.
for tac in generate_tactics("n : ℕ\n⊢ gcd n n = n"):
  print(tac)

rw [gcd_comm]
induction' n with n IH
induction' n with n hn
cases n
rw [gcd]
induction' n with n ih
unfold gcd
rw [gcd_comm, gcd_gcd_self_right]


## Interacting with Lean

[**LeanDojo**](https://github.com/lean-dojo/LeanDojo) supports interacting with Lean in Python. We'll use the `gcd_self` theorem as an example.

![add_abc.jpg](./add_abc.jpg)

In [18]:
# Rebinds `repo` (previously mathlib4) to a small example repository.
repo = LeanGitRepo(
    "https://github.com/yangky11/lean4-example",
    "5117c1d326f0505bef137c7c99099f3f780624b9",
)
# NOTE(review): the prose above says `gcd_self`, and the recorded output of
# `s0.pp` below is a gcd goal, but the name passed here is "add_abc" — confirm
# which theorem this cell is meant to load.
theorem = Theorem(repo, "Lean4Example.lean", "add_abc")

# The context manager is entered by hand so that `dojo` and `s0` stay usable in
# the following cells; nothing visible in this notebook calls __exit__ later.
dojo, s0 = Dojo(theorem).__enter__()



In [19]:
# Show the initial proof state in pretty-printed form.
print(s0.pp)

n : ℕ
⊢ gcd n n = n


In [20]:
# run_tac takes a state and a tactic and returns the result as a new value;
# `s0` is still used again in a later cell.
s1 = dojo.run_tac(s0, "revert n")

print(s1.pp)

⊢ ∀ (n : ℕ), gcd n n = n


In [21]:
# A nonsense tactic: per the recorded output below, run_tac does not raise here
# but returns a LeanError value describing the failure.
s2 = dojo.run_tac(s1, "hello world!")

s2

LeanError(error='<stdin>:1:1: unknown tactic')

In [22]:
# Demonstration: `s2` is a LeanError, not a proof state, so run_tac raises a
# RuntimeError. Catch and print it so the notebook still runs top to bottom.
try:
    dojo.run_tac(s2, "skip")
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: Attempting to run a tactic on an invalid state LeanError(error='<stdin>:1:1: unknown tactic').

In [None]:
# Start again from the initial state `s0` (`s2` is an error, not a usable state).
s3 = dojo.run_tac(s0, "cases n <;> unfold gcd")

print(s3.pp)

In [None]:
# Continue from `s3`. This cell has no recorded output, so the resulting state
# is not shown here.
s4 = dojo.run_tac(s3, "rfl")

print(s4.pp)

In [None]:
# Continue from `s4` with a rewrite (no recorded output for this cell).
s5 = dojo.run_tac(s4, "rw [mod_self]")

print(s5.pp)

In [None]:
# Last step of the walkthrough. The result is displayed as a bare value rather
# than via `.pp` (presumably because it is expected to finish the proof —
# TODO confirm; this cell has no recorded output).
s6 = dojo.run_tac(s5, "simp [gcd]")

s6

## Proof Search

We combine the tactic generator with Depth First Search (DFS) to search for proofs.

In [None]:
# Type aliases used in the annotations of the search functions below.
Tactic = str
Proof = List[Tactic]
# Number of tactic candidates requested from `generate_tactics` at each state.
num_candidates = 32
# `dfs` abandons a branch once its depth exceeds this value.
depth_limit = 5

def dfs(dojo : Dojo, state : TacticState, depth : int) -> Optional[Proof]:
    """Depth-first search for a proof of `state`, guided by `generate_tactics`.

    Args:
        dojo: Environment used to execute tactics through `dojo.run_tac`.
        state: The proof state to close.
        depth: Depth of `state` in the search (the top-level call passes 0).

    Returns:
        The tactics that finish the proof from `state`, in order, or None if
        none of the explored candidate sequences works within the module-level
        `depth_limit`.
    """
    # Check the depth limit.
    if depth > depth_limit:
        print("Hit the depth limit! Backtracking...")
        return None

    # Generate tactic candidates.
    print(f"Current goal:\n{state.pp}\n")
    print("Generating tactics...")
    tactics = generate_tactics(state.pp, num_candidates)
    print(f"{num_candidates} tactic candidates:")
    print("\n".join(tactics) + "\n")

    # Running the generated tactics.
    for tac in tactics:
        next_state = dojo.run_tac(state, tac)
        if isinstance(next_state, ProofFinished):
            # Success!
            print("Found a proof!")
            return [tac]
        if not isinstance(next_state, TacticState):
            # A LeanError, or any other result that is not a usable proof
            # state: treat the candidate as a dead end and try the next one,
            # instead of aborting the whole search on a failed assertion.
            continue

        # Call `dfs` recursively.
        print(f"Applied tactic: {tac}\n")
        subproof = dfs(dojo, next_state, depth + 1)
        if subproof is not None:
            return [tac] + subproof

    print("Unable to prove the current goal. Backtracking...")
    return None


def search(thm : Theorem) -> Optional[Proof]:
    """Search for a proof of `thm` using `dfs`.

    Args:
        thm: The theorem to prove.

    Returns:
        The list of tactics found by `dfs` starting at depth 0, or None if the
        search fails.
    """
    # Open the environment for the argument itself. The original used the
    # notebook-level `theorem` global here and silently ignored `thm`.
    with Dojo(thm) as (dojo, s0):
        return dfs(dojo, s0, 0)

In [None]:
# Show which theorem the search below is about to be run on.
print(theorem.full_name)

In [None]:
# `search` is annotated Optional[Proof]: a list of tactics, or None on failure.
proof = search(theorem)

In [None]:
# `proof` is None when the search fails; joining None would raise a TypeError.
if proof is None:
    print("No proof found.")
else:
    print("\n".join(proof))

## Using the Model in Lean