7 changes: 5 additions & 2 deletions README.md
@@ -11,11 +11,14 @@ Generic interface for hooking up to any Interactive Theorem Prover (ITP) and col
pip install itp-interface
```

2. Run the following command to prepare the REPL for Lean 4. (The default version is 4.7.0-rc2. You can change the version by setting the `LEAN_VERSION` environment variable. If no version is set, then 4.7.0-rc2 is used.)
2. Run the following command to prepare the REPL for Lean 4. The default version is 4.7.0-rc2; you can change it by setting the `LEAN_VERSION` environment variable. itp-interface supports Lean versions up to 4.17.0.
>NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with.
```bash
export LEAN_VERSION="4.15.0"
export LEAN_VERSION="4.7.0-rc2"
install-lean-repl
# ^^ Set LEAN_VERSION to the Lean 4 version your project uses.
# ^^ Example: export LEAN_VERSION="4.15.0" to use Lean 4.15.0
# itp-interface supports versions up to Lean 4.17.0
```

3. Run the following command to build the REPL for Lean 4. Make sure that `lean --version` returns the correct version first; if not, check whether `$HOME/.elan/bin` is on your `PATH`. It is recommended to run `source $HOME/.elan/env` before running the command below.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.7"
version = "1.1.8"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
2 changes: 1 addition & 1 deletion src/itp_interface/main/install.py
@@ -64,7 +64,7 @@ def install_lean_repl():
elan_url = "https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh"
os.system(f"curl -sSL {elan_url} | sh")
print("[OK] .elan installed")
lean_repo = "leanprover-community/lean"
lean_repo = "leanprover/lean4"
# Create a temporary script to run
shell_code = f"""
source $HOME/.elan/env
122 changes: 111 additions & 11 deletions src/itp_interface/main/merge_dataset.py
@@ -11,10 +11,93 @@
import typing
import copy
from itp_interface.tools.log_utils import setup_logger
from itp_interface.tools.training_data import TrainingData
from itp_interface.tools.training_data import TrainingData, TrainingDataFormat

def filter_training_data(training_data: TrainingData, max_distance_to_good: int = 10):

def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=None):
def _reconstruct_proof_tree(prev_state_id_map: typing.Dict[int, typing.Set[int]],
good_state_ids: typing.Set[int],
bad_state_ids: typing.Set[int],
distance_map: typing.Dict[int, int]):
distance = 1
while True:
# The idea is at some point no new good state ids will be added
# This is level order traversal but in reverse
new_good_state_ids = set()
for end_state, start_states in prev_state_id_map.items():
if end_state in good_state_ids:
for start_state in start_states:
if start_state not in good_state_ids:
# To avoid an infinite loop, do not re-add a start_state that is already in good_state_ids
new_good_state_ids.add(start_state)
dist = distance_map.get(start_state, distance)
if dist >= distance:
distance_map[start_state] = distance
distance += 1
if len(new_good_state_ids) == 0:
break
good_state_ids.update(new_good_state_ids)
# Now identify the states which are not good
for end_state in prev_state_id_map.keys():
if end_state not in good_state_ids:
bad_state_ids.add(end_state)

def _reconstruct_prev_state_id_map(training_datas: typing.List[TrainingDataFormat]) -> typing.Tuple[typing.Dict[int, typing.Set[int]], typing.Optional[int]]:
prev_state_id_map = {}
done_state = None
for training_data in training_datas:
if training_data.addition_state_info is not None and len(training_data.addition_state_info) > 0:
start_state = training_data.addition_state_info.get("start_goal_id", None)
end_state = training_data.addition_state_info.get("end_goal_id", None)
if start_state is not None and end_state is not None:
prev_to_end_state = prev_state_id_map.get(end_state, set())
prev_to_end_state.add(start_state)
prev_state_id_map[end_state] = prev_to_end_state
if done_state is None and training_data.addition_state_info.get("done", False):
done_state = end_state
return prev_state_id_map, done_state

filtered_training_data : typing.List[TrainingDataFormat] = []
proof_id_maps : typing.Dict[str, typing.List[TrainingDataFormat]] = {}
for idx in range(len(training_data)):
example = training_data[idx]
training_datas : typing.List[TrainingDataFormat] = proof_id_maps.get(example.proof_id, [])
training_datas.append(example)
proof_id_maps[example.proof_id] = training_datas
for proof_id, training_datas in proof_id_maps.items():
prev_state_id_map, done_state = _reconstruct_prev_state_id_map(training_datas)
# Now we have the prev_state_id_map and done_state
# Every state from which we can reach done_state should have a good value
# Every other state should have a bad value
# First figure out all state ids from where we can reach done_state
good_state_ids = set()
good_state_ids.add(done_state)
bad_state_ids = set()
distance_map = {done_state: 0}
_reconstruct_proof_tree(prev_state_id_map, good_state_ids, bad_state_ids, distance_map)
# Now we have the good_state_ids and bad_state_ids
# Now annotate the training data with the value function
for training_data in training_datas:
if training_data.addition_state_info is not None and len(training_data.addition_state_info) > 0:
end_state_id = training_data.addition_state_info.get("end_goal_id", None)
if end_state_id is not None:
progress = training_data.addition_state_info.get("progress", "")
if end_state_id in good_state_ids and (progress == "StateChanged" or progress == "Done"):
distance = distance_map.get(end_state_id, max_distance_to_good)
distance = min(distance, max_distance_to_good)
progress = f"[GOOD] [{distance}] {progress}"
else:
progress = f"[BAD] {progress}"
training_data.addition_state_info["progress"] = progress
if "GOOD" in progress:
filtered_training_data.append(training_data)
elif training_data.addition_state_info is None:
filtered_training_data.append(training_data)

return filtered_training_data


def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=None, should_filter_data=False):
"""
Merge datasets
"""
@@ -52,10 +135,23 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
for td in tds:
logger.info(f"Start loading {td.folder} ...")
td.load()
logger.info(f"Loaded {td.folder}.")
logger.info(f"Start merging {td.folder} ...")
merged_td.merge(td)
logger.info(f"Merged {td.folder}")

filtered_training_data_points : typing.List[TrainingDataFormat] = None
if should_filter_data and td.folder != datasets[0]:
# TODO: move to the right location
filtered_training_data_points = filter_training_data(td)

if filtered_training_data_points is not None:
logger.info(f"Filtered training data points to {len(filtered_training_data_points)}")
logger.info(f"Start merging {td.folder} after filtering...")
for data in filtered_training_data_points:
merged_td.merge(data)
else:
logger.info(f"Loaded {td.folder}.")
logger.info(f"Start merging {td.folder} ...")
merged_td.merge(td)
logger.info(f"Merged {td.folder}")

logger.info("Finished merging datasets.")
logger.info(f"Start saving merged dataset to {output} ...")
merged_td.save()
@@ -68,11 +164,15 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
max_parallelism=max_parallelism,
logger=logger
)
new_merged_td.load()
assert len(new_merged_td) == metadata.total_proof_step_cnt, "Merged dataset is not correct"
assert new_merged_td.meta.last_training_data == metadata.last_training_data, "Merged dataset is not correct"
assert new_merged_td.meta.last_proof_id == metadata.last_proof_id, "Merged dataset is not correct"
logger.info("Merged dataset is correct.")
if not should_filter_data:
new_merged_td.load()
assert len(new_merged_td) == metadata.total_proof_step_cnt, "Merged dataset is not correct"
assert new_merged_td.meta.last_training_data == metadata.last_training_data, "Merged dataset is not correct"
assert new_merged_td.meta.last_proof_id == metadata.last_proof_id, "Merged dataset is not correct"
logger.info("Merged dataset is correct.")
else:
logger.info("Filtered and merged data, skipping verification")
logger.info("Finished verifying the merged dataset.")
pass
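
A minimal usage sketch of the new filtering path. The dataset folders, metadata filenames, and logger setup below are hypothetical; `merge_datasets` and `should_filter_data` are the ones added in this diff.

```python
import logging

from itp_interface.main.merge_dataset import merge_datasets

# Hypothetical run folders produced by earlier data-generation runs.
datasets = [".log/run1/train", ".log/run2/train"]
metafilenames = ["local.meta.json", "local.meta.json"]  # assumed metadata filenames

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("merge_dataset")

# With should_filter_data=True, every dataset after the first is passed through
# filter_training_data, so only steps annotated [GOOD] (on a path to a Done state)
# are merged, and the post-merge count/metadata verification is skipped.
merge_datasets(
    datasets,
    metafilenames,
    output=".log/merged/train",  # hypothetical output folder
    max_parallelism=8,
    logger=logger,
    should_filter_data=True,
)
```
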

