diff --git a/pyproject.toml b/pyproject.toml index 60dc23f..0833849 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "itp_interface" -version = "1.1.9" +version = "1.1.10" authors = [ { name="Amitayush Thakur", email="amitayush@utexas.edu" }, ] diff --git a/src/data/test/lean4_proj/Lean4Proj/Basic.lean b/src/data/test/lean4_proj/Lean4Proj/Basic.lean index 17c8024..a77254e 100644 --- a/src/data/test/lean4_proj/Lean4Proj/Basic.lean +++ b/src/data/test/lean4_proj/Lean4Proj/Basic.lean @@ -1,3 +1,4 @@ +import Mathlib namespace Lean4Proj1 def hello := "world" @@ -51,4 +52,26 @@ theorem test3 (p q : Prop) (hp : p) (hq : q) exact hq exact hp +theorem imo_1959_p1 + (n : ℕ) + (h₀ : 0 < n) : + Nat.gcd (21*n + 4) (14*n + 3) = 1 := by +rw [Nat.gcd_rec] +rw [Nat.mod_eq_of_lt (by linarith)] +rw [Nat.gcd_rec] +rw [Nat.gcd_rec] +have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by + have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring + rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add] + have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith + rw [Nat.mod_eq_of_lt] + rw [Nat.mod_eq_of_lt] + exact h₂ + rw [Nat.mod_eq_of_lt] + exact h₂ + exact h₂ +rw [eq₂] +sorry + + end Lean4Proj2 diff --git a/src/itp_interface/tools/lean4_sync_executor.py b/src/itp_interface/tools/lean4_sync_executor.py index 22b18ba..9d8311f 100644 --- a/src/itp_interface/tools/lean4_sync_executor.py +++ b/src/itp_interface/tools/lean4_sync_executor.py @@ -831,6 +831,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat proof_running = 'sorries' in response or 'proofState' in response error_messages = response.get('message', None) goal_text = None + goal_texts = [] if error_messages is None and 'proofState' in response: error_messages = response.get('messages', None) elif error_messages is None: @@ -840,6 +841,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat text_msg = msg.get('data', None) if text_msg is not None and text_msg.startswith(Lean4SyncExecutor.unsolved_message): goal_text = text_msg[len(Lean4SyncExecutor.unsolved_message):] + goal_texts.append(goal_text) else: error_messages.append(msg) if len(error_messages) == 0: @@ -865,11 +867,11 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat if self._proof_running: proof_state_idx = None proof_goals = [] - if goal_text is not None: - if len(goal_text) == 0: - proof_goals = [] - else: - proof_goals = [goal_text] + if len(goal_texts) == 0: + proof_goals = [] + elif len(goal_texts) > 0: + proof_goals = [g_text for g_text in goal_texts + if g_text is not None and len(g_text) > 0] elif 'sorries' in response: sorries = response['sorries'] # TODO: Go over all the sorries and find the one which matches the line number with idx + 1 diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py index 5d2fcc0..802f725 100644 --- a/src/test/simple_data_gen_test.py +++ b/src/test/simple_data_gen_test.py @@ -39,8 +39,12 @@ def test_proof_step_data_gen(self): dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean")) print(dirs) last_dir = dirs[-1] - train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train") - data_gen_file = os.path.join(train_data, "local_data_0000000025.json") + train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train") + list_files = os.listdir(train_data) + data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")] + assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}" + print(data_files[0]) + data_gen_file = os.path.join(train_data, data_files[0]) print("Data Gen File:", data_gen_file) with open(data_gen_file, "r") as f: print(f.read()) diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py index 8ea2b8a..d6f3b1d 100644 --- a/src/test/simple_env_test.py +++ b/src/test/simple_env_test.py @@ -7,7 +7,7 @@ def __init__(self): def build_lean4_project(self, project_folder): import os # Build the project - with os.popen(f"cd {project_folder} && lake build") as proc: + with os.popen(f"cd {project_folder} && lake exe cache get && lake build") as proc: print("Building Lean4 project...") print('-'*15 + 'Build Logs' + '-'*15) print(proc.read()) @@ -488,6 +488,88 @@ def test_simple_lean4_done_test(self): print(goal.goal) print(f"="*30) + def test_simple_lean4_have_test(self): + from itp_interface.rl.proof_state import ProofState + from itp_interface.rl.proof_action import ProofAction + from itp_interface.rl.simple_proof_env import ProofEnv + from itp_interface.tools.proof_exec_callback import ProofExecutorCallback + from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy + project_folder = "src/data/test/lean4_proj" + file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" + # Build the project + # cd src/data/test/lean4_proj && lake build + helper = Helper() + helper.build_lean4_project(project_folder) + language = ProofAction.Language.LEAN4 + theorem_name = "imo_1959_p1" + # theorem test3 (p q : Prop) (hp : p) (hq : q) + # : p ∧ q ∧ p := + proof_exec_callback = ProofExecutorCallback( + project_folder=project_folder, + file_path=file_path, + language=language, + always_use_retrieval=False, + keep_local_context=True, + enforce_qed=True + ) + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK + env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) + proof_steps = [ +'rw [Nat.gcd_rec]', +'rw [Nat.mod_eq_of_lt (by linarith)]', +'rw [Nat.gcd_rec]', +'rw [Nat.gcd_rec]', +'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by', +' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring', +' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]', +' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith', +' rw [Nat.mod_eq_of_lt]', +' rw [Nat.mod_eq_of_lt]', +' exact h₂', +' rw [Nat.mod_eq_of_lt]', +' exact h₂', +' exact h₂', +'rw [eq₂]' + ] + with env: + for proof_step in proof_steps: + state, _, next_state, _, done, info = env.step(ProofAction( + ProofAction.ActionType.RUN_TACTIC, + language, + tactics=[proof_step])) + if info.error_message is not None: + print(f"Error: {info.error_message}") + # This prints StateChanged, StateUnchanged, Failed, or Done + print(info.progress) + print('-'*30) + if done: + raise Exception("Proof should not have finished") + else: + s1 : ProofState = state + s2 : ProofState = next_state + print(f"Current Goal:") + print('-'*30) + for goal in s1.training_data_format.start_goals: + hyps = '\n'.join([hyp for hyp in goal.hypotheses]) + print(hyps) + print('|- ', end='') + print(goal.goal) + print(f'*'*30) + print(f"="*30) + print(f"Action: {proof_step}") + print(f"="*30) + print(f"Next Goal:") + print('-'*30) + for goal in s2.training_data_format.start_goals: + hyps = '\n'.join([hyp for hyp in goal.hypotheses]) + print(hyps) + print('|- ', end='') + print(goal.goal) + print(f'*'*30) + print(f"="*30) + print(f"DONE: {done}") + print('-'*30) def main(): unittest.main()