# 1. Initialize the traced repo
We first need to trace Mathlib so that we can interact with the theorems in it.

In [1]:
import os
import pickle
from lean_dojo import LeanGitRepo, trace, Dojo, Theorem, LeanError, TacticState, data_extraction

In [3]:
# Pin Mathlib4 to a fixed commit (the one used in our paper) so results are reproducible.
MATHLIB_URL = "https://github.com/leanprover-community/mathlib4"
MATHLIB_COMMIT = "3c307701fa7e9acbdc0680d7f3b9c9fed9081740"
repo = LeanGitRepo(MATHLIB_URL, MATHLIB_COMMIT)

In [3]:
# Tracing the whole repo is expensive, so the list of candidate theorems is
# cached on disk and only recomputed when the cache file is missing.
cache_path = "./data/candidate_theorems.pickle"
if not os.path.exists(cache_path):
    traced_repo = trace(repo)
    # Keep only theorems whose repo is not flagged `is_lean4`
    # (presumably this drops theorems coming from Lean 4 core -- TODO confirm).
    candidate_theorems = [thm for thm in traced_repo.get_traced_theorems() if not thm.repo.is_lean4]
    # Make sure the cache directory exists, and close the file handle deterministically.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, "wb") as cache_file:
        pickle.dump(candidate_theorems, cache_file)
else:
    # NOTE(security): unpickling executes arbitrary code; only load caches you created yourself.
    with open(cache_path, "rb") as cache_file:
        candidate_theorems = pickle.load(cache_file)

In [4]:
# Load the target theorem for which the dojo env will be initialized,
# e.g. the demo theorem in Fig. 1 of our paper.
demo_path = "./data/demo_theorem.pickle"  # we provide a cached demo theorem
# NOTE(security): unpickling executes arbitrary code; only load pickles from a trusted source.
# Use a context manager so the file handle is closed right after loading.
with open(demo_path, "rb") as demo_file:
    demo_theorem = pickle.load(demo_file)



In [None]:
# Initialize the dojo env for the demo theorem.
# The theorem is identified by its fully qualified name and the path of the file defining it.
demo_name = demo_theorem.theorem.full_name
file_path = demo_theorem.theorem.file_path
theorem = Theorem(repo, file_path, demo_name)
# Remember the working directory *before* entering the dojo: initializing the dojo
# changes the current working directory, and later cells restore it with os.chdir(cwd).
cwd = os.getcwd()
# The context manager is entered manually so that `dojo` and the initial proof state
# `state_0` stay available across cells.
# NOTE(review): no matching __exit__ is visible in this notebook -- confirm the dojo is closed elsewhere.
dojo, state_0 = Dojo(theorem).__enter__()

In [5]:
# Show the source text of the demo theorem by slicing its Lean file between
# the start and end positions recorded on the AST.
ast = demo_theorem.ast
demo_source = ast.lean_file[ast.start:ast.end]
print(demo_source)

theorem Nat.disjoint_divisors_filter_isPrimePow {a b : ℕ} (hab : a.Coprime b) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow) := by
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)


# 2. Find Invocable Theorems
We find the invocable theorems by running specific tactics.

In [14]:
# Initializing the dojo changes the current working directory, so switch back to the
# directory recorded in `cwd`; the project-local modules below are imported from there.
os.chdir(cwd)
import tqdm
from _parse import extract_id_decl_proof, parse_hypothesis
from _logic import get_rw_insts, get_apply_insts
from _utils import load_jsonl, save_jsonl
from lean_dojo.interaction.dojo import DojoHardTimeoutError

**To check whether a specific theorem is invocable**
1. Get the statement and hypotheses of the demo theorem.
2. Construct the tactics that should be run.
3. Run each tactic and check the returned message.

In [16]:
# Split the demo theorem into its parts; the element at index 1 is used as the statement node.
nodes = extract_id_decl_proof(demo_theorem)
statement_node = nodes[1]
# parse_hypothesis returns a 3-tuple; only the first item (hypothesis names) and the
# last item (hypothesis strings) are needed below, the middle one is discarded.
hypothesis = parse_hypothesis(statement_node)
hypo_names, _, hypo_strs = hypothesis

In [48]:
# rw example: build the rewrite tactics for one candidate theorem and run the one
# that rewrites the hypothesis `hab`.
p_t_name = "Nat.coprime_iff_isRelPrime"
rw_insts = get_rw_insts(p_t_name, hypo_names)
print(rw_insts)
rw_at_hab = rw_insts[2]
assert rw_at_hab == 'rw [Nat.coprime_iff_isRelPrime] at hab'
next_state = dojo.run_tac(state_0, rw_at_hab)
separator = "-" * 20
print(separator)
print(state_0.pp)
print(separator)
# the hypothesis is mutated (invocable)
print(next_state.pp)

['rw [Nat.coprime_iff_isRelPrime]', 'rw [←Nat.coprime_iff_isRelPrime]', 'rw [Nat.coprime_iff_isRelPrime] at hab', 'rw [←Nat.coprime_iff_isRelPrime] at hab']
--------------------
R : Type u_1
inst✝ : CommMonoidWithZero R
n p : R
k a b : ℕ
hab : Coprime a b
⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))
--------------------
R : Type u_1
inst✝ : CommMonoidWithZero R
n p : R
k a b : ℕ
hab : IsRelPrime a b
⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))


In [17]:
# apply example: try to re-derive the hypothesis `hab` by applying a candidate theorem.
p_t_name = "Nat.coprime_of_mul_modEq_one"
apply_insts = get_apply_insts(p_t_name, hypo_strs)
print(apply_insts)
apply_for_hab = apply_insts[0]
assert apply_for_hab == 'have hab : a.Coprime b := by apply Nat.coprime_of_mul_modEq_one'
next_state = dojo.run_tac(state_0, apply_for_hab)
separator = "-" * 20
print(separator)
print(state_0.pp)
print(separator)
# the tactic produces new unsolved subgoals (invocable)
print(next_state.error)

['have hab : a.Coprime b := by apply Nat.coprime_of_mul_modEq_one']
--------------------
R : Type u_1
inst✝ : CommMonoidWithZero R
n p : R
k a b : ℕ
hab : Coprime a b
⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))
--------------------
unsolved goals
case h
R : Type u_1
inst✝ : CommMonoidWithZero R
n p : R
k a b : ℕ
hab : Coprime a b
⊢ a * ?b ≡ 1 [MOD b]

case b
R : Type u_1
inst✝ : CommMonoidWithZero R
n p : R
k a b : ℕ
hab : Coprime a b
⊢ ℕ


We implement the logic for finding new theorems in one api:

- **get_invocable_theorems_with_dojo**

*In practice, to support data-synthesis on multiple cpus, we run the data_generator.py on each cpu node.*

In [45]:
from _logic import get_invocable_theorems_with_dojo

In [8]:
# Find invocable theorems using rw (note: mode="rw").
# Every candidate theorem is tried against the demo theorem's initial state `state_0`.
possibly_invocable_theorems = candidate_theorems
rw_theorems = get_invocable_theorems_with_dojo(
    demo_theorem,dojo,state_0, possibly_invocable_theorems, mode="rw"
)
# Only the element at index 1 of the returned value is persisted.
save_jsonl(rw_theorems[1], "./outputs/rw_invocable_theorems_demo.jsonl")

Checking invocable theorems:   3%|▎         | 3947/116695 [00:35<15:36, 120.38it/s]

Std.Tactic.Ext.$extName
Std.Tactic.Ext.$extIffName


Checking invocable theorems:  45%|████▌     | 52802/116695 [06:51<07:15, 146.84it/s]

Lean.Elab.Command.$n_def


Checking invocable theorems: 100%|██████████| 116695/116695 [14:49<00:00, 131.16it/s]


In [13]:
# Find invocable theorems using apply (same candidates and initial state as the rw run).
apply_theorems = get_invocable_theorems_with_dojo(
    demo_theorem,dojo,state_0, possibly_invocable_theorems, mode="apply"
)
# Only the element at index 1 of the returned value is persisted.
save_jsonl(apply_theorems[1], "./outputs/apply_invocable_theorems_demo.jsonl")

Checking invocable theorems:   3%|▎         | 3966/116695 [00:14<05:29, 342.03it/s]

Std.Tactic.Ext.$extName
Std.Tactic.Ext.$extIffName


Checking invocable theorems:  46%|████▌     | 53099/116695 [01:15<00:46, 1357.54it/s]

Lean.Elab.Command.$n_def


Checking invocable theorems: 100%|██████████| 116695/116695 [02:18<00:00, 840.47it/s] 


# 3. Mutate and Verify the Original Theorems

## 3.1 Example
We implement the process to mutate the target theorem for an invocable theorem into two apis:
- **modify_theorem_rw**
- **modify_theorem_apply**

Note that the implementations of these two APIs are not perfect and can be further improved.

In [9]:
# Switch back to the directory recorded in `cwd` (initializing the dojo changes the
# working directory) before importing the project-local module.
os.chdir(cwd)
from _modify import modify_theorem_rw, modify_theorem_apply

In [23]:
# For example: each instance records the proof state before the tactic ("init_state"),
# the tactic itself ("rule") and the resulting state or error message ("next_state").
rw_invocable_inst = {
    "init_state": "R : Type u_1\ninst✝ : CommMonoidWithZero R\nn p : R\nk a b : ℕ\nhab : Coprime a b\n⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))",
    "rule": "rw [Nat.coprime_iff_isRelPrime] at hab",
    "next_state": "R : Type u_1\ninst✝ : CommMonoidWithZero R\nn p : R\nk a b : ℕ\nhab : IsRelPrime a b\n⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))",
}
print("*"*40 + "rw variant" + "*"*40)
print(modify_theorem_rw(demo_theorem, rw_invocable_inst)) # Equality-Variant in Fig-1

# For apply, "next_state" holds the "unsolved goals" error produced by the tactic.
apply_invocable_inst = {
    "init_state": "R : Type u_1\ninst✝ : CommMonoidWithZero R\nn p : R\nk a b : ℕ\nhab : Coprime a b\n⊢ Disjoint (Finset.filter IsPrimePow (divisors a)) (Finset.filter IsPrimePow (divisors b))",
    "rule": "have hab : a.Coprime b := by apply Nat.coprime_of_mul_modEq_one",
    "next_state": "unsolved goals\ncase h\nR : Type u_1\ninst✝ : CommMonoidWithZero R\nn p : R\nk a b : ℕ\nhab : Coprime a b\n⊢ a * ?b ≡ 1 [MOD b]\n\ncase b\nR : Type u_1\ninst✝ : CommMonoidWithZero R\nn p : R\nk a b : ℕ\nhab : Coprime a b\n⊢ ℕ"
}

print("*"*40 + "apply variant" + "*"*40)
print(modify_theorem_apply(demo_theorem, apply_invocable_inst)) # Apply-Variant in Fig-1

****************************************rw variant****************************************
example {a b : ℕ} (hab : IsRelPrime a b) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by rw [←Nat.coprime_iff_isRelPrime] at hab;exact hab
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)
****************************************apply variant****************************************
example {a b : ℕ} (g : ℕ) (h : a * g ≡ 1 [MOD b]) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by apply Nat.coprime_of_mul_modEq_one<;> assumption
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)


## 3.2 Mutate 
In practice, for rw and apply, we mutate the theorems and write variants back to the original Lean file (record the line number of each theorem). 

In [72]:
from _logic import filtering_single_goal

In [73]:
# To handle multiple target theorems, just extend file2theorem and theorem2output.
# GEN_MODE selects the tactic used for synthesis ("rw" or "apply"); it also selects
# which invocable-theorems file is read and how the outputs below are named.
# GEN_MODE = "rw"
GEN_MODE = "apply"
file_path = demo_theorem.theorem.file_path.as_posix()
file2theorem = {file_path: [demo_theorem]} # file path -> list of traced theorems in that file
theorem2output ={demo_theorem: f"./outputs/{GEN_MODE}_invocable_theorems_demo.jsonl"}# target theorem -> path of its invocable-theorems jsonl

In [None]:
# For each traced file:
#   1. for every target theorem in the file, build its variants from the invocable
#      theorems found earlier and remember where the theorem starts/ends;
#   2. insert the variants of all target theorems right after their original theorem,
#      all at once, recording the (1-indexed, inclusive) line range of each variant.
# The result is one record per file: {'file_name', 'loc', 'text'}.

# Fail early with a clear message instead of a NameError deep inside the loop.
if GEN_MODE not in ("rw", "apply"):
    raise ValueError(f"Unsupported GEN_MODE: {GEN_MODE!r} (expected 'rw' or 'apply')")

lean_corpus = []
test_files = list(file2theorem.items())
for traced_file, target_theorems in tqdm.tqdm(test_files, total=len(test_files)):
    modifier_blackboard = {}
    # Reset per file so a file without target theorems can never reuse another file's text.
    leanfile = None
    for target_theorem in target_theorems:
        # locate this theorem
        theorem_ast = target_theorem.ast
        leanfile = theorem_ast.lean_file
        start, end = theorem_ast.start, theorem_ast.end
        # get the output_file
        output_file = theorem2output[target_theorem]
        # load the invocable theorems found for this target theorem
        try:
            invocable_theorems = list(load_jsonl(output_file))
        except Exception:
            # Best effort: skip this theorem. Without the `continue`, the code below
            # would read an undefined (or stale) `new_theorems`.
            print(f"Error in loading the file {output_file}")
            continue

        if GEN_MODE == "rw":
            filtered_invocable_theorems = filtering_single_goal(invocable_theorems)
        else:
            filtered_invocable_theorems = invocable_theorems

        new_theorems = []
        for invocable_theorem in filtered_invocable_theorems:
            # each record is a mapping; the instances are its first value
            insts = list(invocable_theorem.values())[0]
            for inst in insts:
                if GEN_MODE == "rw":
                    new_theorem = modify_theorem_rw(target_theorem, inst)
                else:
                    # modify_theorem_apply may fail on some instances; those are skipped
                    try:
                        new_theorem = modify_theorem_apply(target_theorem, inst)
                    except Exception:
                        new_theorem = None
                if new_theorem:
                    new_theorems.append(new_theorem)

        if not modifier_blackboard.get(target_theorem, 0) and len(new_theorems) > 0:
            modifier_blackboard[target_theorem] = {"start":start, "end": end, "new_theorems": new_theorems}

    if leanfile is None:
        # No target theorem means there is no Lean file object to take the text from.
        print(f"No target theorems for the file {traced_file}, skipped")
        continue

    # replace this file all at once
    res = dict()
    res['file_name'] = traced_file
    res['loc'] = dict() # e.g. {theorem_name: [(start, end), (start, end)]}
    offset = 0  # number of lines inserted so far above the theorem being processed
    old_file = '\n'.join(leanfile.code)

    if len(modifier_blackboard.keys()) != 0:
        # The running `offset` is only correct when theorems are handled from the top of
        # the file to the bottom, so process them in ascending start-line order.
        ordered_modifiers = sorted(
            modifier_blackboard.items(), key=lambda item: item[1]["start"].line_nb
        )
        for target_theorem, modifier in ordered_modifiers:
            name = target_theorem.theorem.full_name
            res['loc'][name] = []
            start, end, new_theorems = modifier["start"], modifier["end"], modifier["new_theorems"]
            original_theorem = leanfile[start:end]
            start = start.line_nb + offset
            end = end.line_nb + offset
            original_end = end
            # get the location of each modified theorem
            for idx, new_theorem in enumerate(new_theorems):
                if idx == 0:
                    # two blank lines separate the original theorem from its first variant
                    theorem_start = end + 3
                else:
                    # one blank line separates consecutive variants
                    theorem_start = end + 2
                theorem_end = theorem_start + len(new_theorem.split('\n')) - 1
                end = theorem_end
                res['loc'][name].append((theorem_start, theorem_end))
            offset += theorem_end - original_end
            new_theorem = original_theorem +"\n" + "\n\n" + "\n\n".join(new_theorems)
            # NOTE(review): str.replace substitutes every occurrence of the theorem text;
            # the recorded locations assume the text occurs exactly once in the file.
            old_file = old_file.replace(original_theorem, new_theorem)
        res['text'] = old_file
    else:
        res['text'] = old_file
        res['loc'] = ['No variants']
    lean_corpus.append(res)

print(f"The file number of lean_corpus is {len(lean_corpus)}")

100%|██████████| 1/1 [00:00<00:00, 13.62it/s]

The file number of lean_corpus is 1





In [75]:
# Persist the (not yet verified) variants; the verification step reloads them from this file.
unverified_corpus_path = f"./outputs/{GEN_MODE}_variants_demo_without_verify.jsonl"
save_jsonl(lean_corpus, unverified_corpus_path)

## 3.3 Verify
We build the repo with the variants inserted and remove the variants that fail to compile:
1. Clone Mathlib4.
2. Write the variants into the Lean file, then build.

In [76]:
# clone the mathlib (run once)
# !git clone https://github.com/leanprover-community/mathlib4

In [77]:
# !cd mathlib4 && git checkout 3c307701fa7e9acbdc0680d7f3b9c9fed9081740

In [78]:
import re
import numpy as np
from verify import lake_build, generate_text

*The code below is largely adapted from verify.py, a distributed script used to verify synthesized theorems.*

In [79]:
# Verify the synthesized variants: for every rewritten file, temporarily overwrite the
# file inside the local Mathlib clone, point the library root file at that single module,
# run `lake build`, and keep only the variants whose line ranges contain no build error.
mathlib_package_path = "./mathlib4"
output_path = f"outputs/verify/{GEN_MODE}"
os.makedirs(output_path, exist_ok=True)  # no exists()/makedirs race

PACKAGE_NAME = "Mathlib"
LIBRARY_ROOT_FILE = os.path.join(mathlib_package_path, PACKAGE_NAME + '.lean')
corpus = lean_corpus
corpus_path = f"./outputs/{GEN_MODE}_variants_demo_without_verify.jsonl"
files = list(load_jsonl(corpus_path))
print(f"There are {len(files)} files in total.")
res = []
for file in tqdm.tqdm(files, total=len(files)):
    file_name = file['file_name']
    file_path = os.path.join(mathlib_package_path, file_name)
    # extract the old content of this file (Lean sources are UTF-8)
    with open(file_path, "r", encoding="utf-8") as f:
        old_str = f.read()
    # initialize the result
    tmp = dict()
    tmp['file_name'] = file_name
    tmp['original_text'] = old_str
    tmp['text'] = file['text']
    tmp['loc'] = file['loc']
    tmp['valid_loc'] = []

    try:
        # replace the old content with the new content
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(file['text'])
        # change the build target to the current file
        # NOTE(review): the library root file is overwritten here and never restored.
        with open(LIBRARY_ROOT_FILE, 'w', encoding="utf-8") as f:
            # strip only the trailing extension, then turn the path into a module name
            module_name = os.path.splitext(file_name)[0].replace('/', '.')
            f.write(f"import {module_name}")

        if file['loc'] != ['No variants']: # if there exist variants in this file
            ## lake build the new mathlib project
            wd = os.getcwd()
            result = lake_build(mathlib_package_path)
            os.chdir(wd)  # go back in case the build changed the working directory
            print(f"lake build the rewrited file {file_name}")
            ## parse the output; the result is interpreted as: None -> subprocess error,
            ## 0 -> success, -1 -> timeout, anything else -> the build log text
            if result is None:
                tmp['valid_loc'] = ["subprocess error"]
            elif result == 0:
                tmp['valid_loc'] = tmp['loc']
                print('successful build')
            elif result == -1:
                tmp['valid_loc'] = ["build timeout error"]
            else:
                # find the error locations (line numbers); the file name is escaped so
                # that '.' and other regex metacharacters in the path match literally
                pattern = fr"({re.escape(file_name)}):(\d+):(\d+): error:"
                errors = re.findall(pattern, result)
                if len(errors) == 0:
                    tmp['valid_loc'] = ['parse error']
                else:
                    error_line_nbs = np.array(sorted({int(line_nb) for _, line_nb, _ in errors}))

                    # get the locations of all variants
                    intervals = []
                    for t, locs in file['loc'].items():
                        intervals.extend(locs)
                    intervals = np.array(intervals)
                    if len(intervals) != 0:
                        # Boolean matrix of shape (n_intervals, n_errors): True where an
                        # error line falls inside the corresponding variant interval
                        in_interval = (intervals[:, 0, np.newaxis] <= error_line_nbs) & (error_line_nbs <= intervals[:, 1, np.newaxis])
                        # a variant is valid when no error line falls inside it
                        valid_intervals = intervals[in_interval.sum(axis=1)==0]
                        with open(file_path, "w", encoding="utf-8") as f:
                            f.write(generate_text(file['text'], valid_intervals, intervals))
                        ## rebuild the project to confirm the remaining variants compile
                        wd = os.getcwd()
                        result = lake_build(mathlib_package_path)
                        os.chdir(wd)
                        print(f"lake rebuild the rewrited file {file_name}")

                        if result != 0:
                            print("rebuild exception")
                            tmp['valid_loc'] = ['rebuild error']
                        else:
                            tmp['valid_loc'] = valid_intervals.tolist()
        else:
            tmp['valid_loc'] = ['No variants']
    finally:
        # always write back the original content, even if the build or parsing raised
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(old_str)
    res.append(tmp)
save_jsonl(res, os.path.join(output_path, f"verified_results_{GEN_MODE}_demo.jsonl"))

There are 1 files in total.


  0%|          | 0/1 [00:00<?, ?it/s]

build error
lake build the rewrited file Mathlib/Algebra/IsPrimePow.lean


100%|██████████| 1/1 [00:18<00:00, 18.00s/it]

lake rebuild the rewrited file Mathlib/Algebra/IsPrimePow.lean





In [None]:
# Post-process the verified results: attach to every file record a `formatted_loc`
# mapping {theorem_name: [valid (start, end) locations]}.
ERROR_MARKERS = [["subprocess error"], ["build timeout error"], ["parse error"], ["rebuild error"], ["No variants"]]
for idx, result in enumerate(res):
    valid_loc = result['valid_loc']
    if isinstance(valid_loc, dict):
        # The first build succeeded, so `valid_loc` is the complete `loc` mapping and
        # every variant is valid. This value used to be overwritten with an empty dict
        # right afterwards, which dropped all variants of fully successful files.
        formatted_loc = result['loc']
    else:
        formatted_loc = dict()
        # a list is either an error marker (nothing is valid) or the list of valid intervals
        if isinstance(valid_loc, list) and valid_loc not in ERROR_MARKERS:
            # for each changed theorem keep only the locations that passed verification
            for ts, locs in result['loc'].items():
                formatted_loc[ts] = [loc for loc in locs if loc in valid_loc]
    res[idx]['formatted_loc'] = formatted_loc

# records that have at least one valid variant
results = [r for r in res if r['formatted_loc'] and any(r['formatted_loc'].values())]

# final corpus: one entry per file, including files without any valid variant
outputs = []
for line in res:
    tmp = {
        'file_name': line['file_name'],
        'original_text': line['original_text'],
        'text': line['text'],
        'valid_loc': line['formatted_loc'],
        'loc': line['loc'],
        'meta': "https://github.com/leanprover-community/mathlib4/commit/3c307701fa7e9acbdc0680d7f3b9c9fed9081740"
    }
    outputs.append(tmp)

In [82]:
# Persist the final corpus for the current GEN_MODE.
synthesized_corpus_path = f"./outputs/synthesized_corpus_{GEN_MODE}_demo.jsonl"
save_jsonl(outputs, synthesized_corpus_path)

*To synthesize theorems with a different tactic, just change GEN_MODE and re-run the cells above.*

## 3.4 Display the variants 
We have now synthesized variants of the demo theorem through rw and apply.
Let us take a brief look at them.

In [83]:
# Load the corpora synthesized with each tactic (one run of the pipeline per GEN_MODE).
corpus_path_template = "./outputs/synthesized_corpus_{}_demo.jsonl"
rw_corpus = list(load_jsonl(corpus_path_template.format("rw")))
apply_corpus = list(load_jsonl(corpus_path_template.format("apply")))

*We only synthesize one variant for the demo theorem through rw, which is the example in Fig 1 of our paper*

In [85]:
# Print every valid rw variant of the first file; locations are 1-indexed, inclusive line ranges.
rw_record = rw_corpus[0]
text = rw_record['text']
valid_locs = rw_record['valid_loc']
text_lines = text.split('\n')
for _theorem_name, spans in valid_locs.items():
    for variant_idx, span in enumerate(spans):
        print(f"Variant {variant_idx}")
        variant_lines = text_lines[span[0]-1:span[1]]
        print('\n'.join(variant_lines))

Variant 0
example {a b : ℕ} (hab : IsRelPrime a b) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by rw [←Nat.coprime_iff_isRelPrime] at hab;exact hab
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)


*We synthesize 5 variants for the demo theorem through apply, including the example in Fig 1*

In [86]:
# Print every valid apply variant of the first file; locations are 1-indexed, inclusive line ranges.
apply_record = apply_corpus[0]
text = apply_record['text']
valid_locs = apply_record['valid_loc']
text_lines = text.split('\n')
for _theorem_name, spans in valid_locs.items():
    for variant_idx, span in enumerate(spans):
        print(f"Variant {variant_idx}")
        variant_lines = text_lines[span[0]-1:span[1]]
        print('\n'.join(variant_lines))

Variant 0
example {a b : ℕ} (g : b ∈ []) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by apply List.forall_mem_nil; assumption
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)
Variant 1
example {a b : ℕ} (r : ℕ) (h : a * r ≡ 1 [MOD b]) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by apply Nat.coprime_of_mul_modEq_one<;> assumption
  simp only [Finset.disjoint_left, Finset.mem_filter, and_imp, Nat.mem_divisors, not_and]
  rintro n han _ha hn hbn _hb -
  exact hn.ne_one (Nat.eq_one_of_dvd_coprimes hab han hbn)
Variant 2
example {a b : ℕ} (s : Set ℕ) (h : Set.Subsingleton s) (d : a ∈ s) (d : b ∈ s) (d : a ≠ b) :
    Disjoint (a.divisors.filter IsPrimePow) (b.divisors.filter IsPrimePow):= by
  have hab : a.Coprime b := by apply Set.

# 4. Construct the training data
Since we have synthesized variants for the target theorem and recorded their locations, we can construct training data for:
1. Continual Pretrain
2. Supervised Fine-tuning
   
To retrieve the state-tactic pairs, we need to
1. Write back all variants to original repo.
2. Use Leandojo to trace this repo.
3. Parse the resulting ASTs to extract the additional state-tactic pairs.
   
*For the extraction stage in step 3, we provide a multi-process implementation in extract_state_tactic.py*