Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.9"
version = "1.1.10"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
Expand Down
23 changes: 23 additions & 0 deletions src/data/test/lean4_proj/Lean4Proj/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Mathlib
namespace Lean4Proj1

def hello := "world"
Expand Down Expand Up @@ -51,4 +52,26 @@ theorem test3 (p q : Prop) (hp : p) (hq : q)
exact hq
exact hp

theorem imo_1959_p1
(n : ℕ)
(h₀ : 0 < n) :
Nat.gcd (21*n + 4) (14*n + 3) = 1 := by
rw [Nat.gcd_rec]
rw [Nat.mod_eq_of_lt (by linarith)]
rw [Nat.gcd_rec]
rw [Nat.gcd_rec]
have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by
have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring
rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]
have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith
rw [Nat.mod_eq_of_lt]
rw [Nat.mod_eq_of_lt]
exact h₂
rw [Nat.mod_eq_of_lt]
exact h₂
exact h₂
rw [eq₂]
sorry


end Lean4Proj2
12 changes: 7 additions & 5 deletions src/itp_interface/tools/lean4_sync_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
proof_running = 'sorries' in response or 'proofState' in response
error_messages = response.get('message', None)
goal_text = None
goal_texts = []
if error_messages is None and 'proofState' in response:
error_messages = response.get('messages', None)
elif error_messages is None:
Expand All @@ -840,6 +841,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
text_msg = msg.get('data', None)
if text_msg is not None and text_msg.startswith(Lean4SyncExecutor.unsolved_message):
goal_text = text_msg[len(Lean4SyncExecutor.unsolved_message):]
goal_texts.append(goal_text)
else:
error_messages.append(msg)
if len(error_messages) == 0:
Expand All @@ -865,11 +867,11 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
if self._proof_running:
proof_state_idx = None
proof_goals = []
if goal_text is not None:
if len(goal_text) == 0:
proof_goals = []
else:
proof_goals = [goal_text]
if len(goal_texts) == 0:
proof_goals = []
elif len(goal_texts) > 0:
proof_goals = [g_text for g_text in goal_texts
if g_text is not None and len(g_text) > 0]
elif 'sorries' in response:
sorries = response['sorries']
# TODO: Go over all the sorries and find the one which matches the line number with idx + 1
Expand Down
8 changes: 6 additions & 2 deletions src/test/simple_data_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ def test_proof_step_data_gen(self):
dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean"))
print(dirs)
last_dir = dirs[-1]
train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
data_gen_file = os.path.join(train_data, "local_data_0000000025.json")
train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
list_files = os.listdir(train_data)
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
print(data_files[0])
data_gen_file = os.path.join(train_data, data_files[0])
print("Data Gen File:", data_gen_file)
with open(data_gen_file, "r") as f:
print(f.read())
Expand Down
84 changes: 83 additions & 1 deletion src/test/simple_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self):
def build_lean4_project(self, project_folder):
import os
# Build the project
with os.popen(f"cd {project_folder} && lake build") as proc:
with os.popen(f"cd {project_folder} && lake exe cache get && lake build") as proc:
print("Building Lean4 project...")
print('-'*15 + 'Build Logs' + '-'*15)
print(proc.read())
Expand Down Expand Up @@ -488,6 +488,88 @@ def test_simple_lean4_done_test(self):
print(goal.goal)
print(f"="*30)

def test_simple_lean4_have_test(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.simple_proof_env import ProofEnv
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
project_folder = "src/data/test/lean4_proj"
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = "imo_1959_p1"
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
file_path=file_path,
language=language,
always_use_retrieval=False,
keep_local_context=True,
enforce_qed=True
)
always_retrieve_thms = False
retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
proof_steps = [
'rw [Nat.gcd_rec]',
'rw [Nat.mod_eq_of_lt (by linarith)]',
'rw [Nat.gcd_rec]',
'rw [Nat.gcd_rec]',
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith',
' rw [Nat.mod_eq_of_lt]',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' exact h₂',
'rw [eq₂]'
]
with env:
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
ProofAction.ActionType.RUN_TACTIC,
language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
# This prints StateChanged, StateUnchanged, Failed, or Done
print(info.progress)
print('-'*30)
if done:
raise Exception("Proof should not have finished")
else:
s1 : ProofState = state
s2 : ProofState = next_state
print(f"Current Goal:")
print('-'*30)
for goal in s1.training_data_format.start_goals:
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
print(hyps)
print('|- ', end='')
print(goal.goal)
print(f'*'*30)
print(f"="*30)
print(f"Action: {proof_step}")
print(f"="*30)
print(f"Next Goal:")
print('-'*30)
for goal in s2.training_data_format.start_goals:
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
print(hyps)
print('|- ', end='')
print(goal.goal)
print(f'*'*30)
print(f"="*30)
print(f"DONE: {done}")
print('-'*30)

def main():
unittest.main()
Expand Down